clip.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. """Minimal implementation of CLIPVisionModel intended to be only used
  2. within a vision language model."""
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. from PIL import Image
  7. from transformers import CLIPVisionConfig
  8. from transformers.models.clip.modeling_clip import CLIPAttention
  9. from aphrodite.common.config import ModelConfig
  10. from aphrodite.common.sequence import SequenceData
  11. from aphrodite.inputs import LLMInputs
  12. from aphrodite.modeling.layers.activation import get_act_fn
  13. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  14. RowParallelLinear)
  15. from aphrodite.multimodal.image import (cached_get_tokenizer,
  16. repeat_and_pad_image_tokens)
  17. from aphrodite.quantization import QuantizationConfig
  18. def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
  19. assert image_size % patch_size == 0
  20. return image_size // patch_size
  21. def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
  22. grid_length = get_clip_patch_grid_length(image_size=image_size,
  23. patch_size=patch_size)
  24. return grid_length * grid_length
  25. def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
  26. return get_clip_num_patches(image_size=hf_config.image_size,
  27. patch_size=hf_config.patch_size)
  28. def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
  29. return get_clip_image_feature_size(hf_config)
  30. def dummy_seq_data_for_clip(
  31. hf_config: CLIPVisionConfig,
  32. seq_len: int,
  33. *,
  34. image_token_id: int,
  35. image_feature_size_override: Optional[int] = None,
  36. ):
  37. if image_feature_size_override is None:
  38. image_feature_size = get_clip_image_feature_size(hf_config)
  39. else:
  40. image_feature_size = image_feature_size_override
  41. token_ids = [image_token_id] * image_feature_size
  42. token_ids += [0] * (seq_len - image_feature_size)
  43. return SequenceData(token_ids)
  44. def dummy_image_for_clip(
  45. hf_config: CLIPVisionConfig,
  46. *,
  47. image_width_override: Optional[int] = None,
  48. image_height_override: Optional[int] = None,
  49. ):
  50. width = height = hf_config.image_size
  51. if image_width_override is not None:
  52. width = image_width_override
  53. if image_height_override is not None:
  54. height = image_height_override
  55. image = Image.new("RGB", (width, height), color=0)
  56. return {"image": image}
  57. def input_processor_for_clip(
  58. model_config: ModelConfig,
  59. hf_config: CLIPVisionConfig,
  60. llm_inputs: LLMInputs,
  61. *,
  62. image_token_id: int,
  63. image_feature_size_override: Optional[int] = None,
  64. ):
  65. multi_modal_data = llm_inputs.get("multi_modal_data")
  66. if multi_modal_data is None or "image" not in multi_modal_data:
  67. return llm_inputs
  68. tokenizer = cached_get_tokenizer(model_config.tokenizer)
  69. if image_feature_size_override is None:
  70. image_feature_size = get_clip_image_feature_size(hf_config)
  71. else:
  72. image_feature_size = image_feature_size_override
  73. new_prompt, new_token_ids = repeat_and_pad_image_tokens(
  74. tokenizer,
  75. llm_inputs.get("prompt"),
  76. llm_inputs["prompt_token_ids"],
  77. image_token_id=image_token_id,
  78. repeat_count=image_feature_size,
  79. )
  80. # NOTE: Create a defensive copy of the original inputs
  81. return LLMInputs(prompt_token_ids=new_token_ids,
  82. prompt=new_prompt,
  83. multi_modal_data=multi_modal_data)
  84. # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
  85. class CLIPVisionEmbeddings(nn.Module):
  86. def __init__(self, config: CLIPVisionConfig):
  87. super().__init__()
  88. self.config = config
  89. self.embed_dim = config.hidden_size
  90. self.image_size = config.image_size
  91. self.patch_size = config.patch_size
  92. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  93. self.patch_embedding = nn.Conv2d(
  94. in_channels=config.num_channels,
  95. out_channels=self.embed_dim,
  96. kernel_size=self.patch_size,
  97. stride=self.patch_size,
  98. bias=False,
  99. )
  100. self.num_patches = get_clip_num_patches(image_size=self.image_size,
  101. patch_size=self.patch_size)
  102. self.num_positions = self.num_patches + 1
  103. self.position_embedding = nn.Embedding(self.num_positions,
  104. self.embed_dim)
  105. self.register_buffer("position_ids",
  106. torch.arange(self.num_positions).expand((1, -1)),
  107. persistent=False)
  108. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  109. batch_size = pixel_values.shape[0]
  110. target_dtype = self.patch_embedding.weight.dtype
  111. patch_embeds = self.patch_embedding(pixel_values.to(
  112. dtype=target_dtype)) # shape = [*, width, grid, grid]
  113. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  114. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  115. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  116. embeddings = embeddings + self.position_embedding(self.position_ids)
  117. return embeddings
  118. class CLIPMLP(nn.Module):
  119. def __init__(self,
  120. config: CLIPVisionConfig,
  121. quant_config: Optional[QuantizationConfig] = None):
  122. super().__init__()
  123. self.config = config
  124. self.activation_fn = get_act_fn(config.hidden_act)
  125. self.fc1 = ColumnParallelLinear(config.hidden_size,
  126. config.intermediate_size,
  127. bias=True,
  128. quant_config=quant_config)
  129. self.fc2 = RowParallelLinear(config.intermediate_size,
  130. config.hidden_size,
  131. bias=True,
  132. quant_config=quant_config)
  133. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  134. hidden_states, _ = self.fc1(hidden_states)
  135. hidden_states = self.activation_fn(hidden_states)
  136. hidden_states, _ = self.fc2(hidden_states)
  137. return hidden_states
  138. class CLIPEncoderLayer(nn.Module):
  139. def __init__(self,
  140. config: CLIPVisionConfig,
  141. quant_config: Optional[QuantizationConfig] = None):
  142. super().__init__()
  143. self.self_attn = CLIPAttention(config)
  144. self.layer_norm1 = nn.LayerNorm(config.hidden_size,
  145. eps=config.layer_norm_eps)
  146. self.mlp = CLIPMLP(config, quant_config=quant_config)
  147. self.layer_norm2 = nn.LayerNorm(config.hidden_size,
  148. eps=config.layer_norm_eps)
  149. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  150. residual = hidden_states
  151. hidden_states = self.layer_norm1(hidden_states)
  152. hidden_states, _ = self.self_attn(hidden_states=hidden_states)
  153. hidden_states = residual + hidden_states
  154. residual = hidden_states
  155. hidden_states = self.layer_norm2(hidden_states)
  156. hidden_states = self.mlp(hidden_states)
  157. hidden_states = residual + hidden_states
  158. return hidden_states
  159. class CLIPEncoder(nn.Module):
  160. """
  161. Transformer encoder consisting of `config.num_hidden_layers` self
  162. attention layers. Each layer is a [`CLIPEncoderLayer`].
  163. Args:
  164. config: CLIPConfig
  165. """
  166. def __init__(self,
  167. config: CLIPVisionConfig,
  168. quant_config: Optional[QuantizationConfig] = None,
  169. num_hidden_layers_override: Optional[int] = None):
  170. super().__init__()
  171. self.config = config
  172. if num_hidden_layers_override is None:
  173. num_hidden_layers = config.num_hidden_layers
  174. else:
  175. num_hidden_layers = num_hidden_layers_override
  176. self.layers = nn.ModuleList([
  177. CLIPEncoderLayer(config=config, quant_config=quant_config)
  178. for _ in range(num_hidden_layers)
  179. ])
  180. def forward(self, inputs_embeds: torch.Tensor):
  181. hidden_states = inputs_embeds
  182. for encoder_layer in self.layers:
  183. hidden_states = encoder_layer(hidden_states)
  184. return hidden_states
  185. class CLIPVisionTransformer(nn.Module):
  186. def __init__(self,
  187. config: CLIPVisionConfig,
  188. quant_config: Optional[QuantizationConfig] = None,
  189. num_hidden_layers_override: Optional[int] = None):
  190. super().__init__()
  191. self.config = config
  192. embed_dim = config.hidden_size
  193. self.embeddings = CLIPVisionEmbeddings(config)
  194. # NOTE: This typo of "layrnorm" is not fixed on purpose to match
  195. # the original transformers code and name of the model weights.
  196. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  197. self.encoder = CLIPEncoder(
  198. config=config,
  199. quant_config=quant_config,
  200. num_hidden_layers_override=num_hidden_layers_override)
  201. def forward(
  202. self,
  203. pixel_values: torch.Tensor,
  204. ) -> torch.Tensor:
  205. hidden_states = self.embeddings(pixel_values)
  206. hidden_states = self.pre_layrnorm(hidden_states)
  207. hidden_states = self.encoder(inputs_embeds=hidden_states)
  208. return hidden_states
  209. class CLIPVisionModel(nn.Module):
  210. config_class = CLIPVisionConfig
  211. main_input_name = "pixel_values"
  212. def __init__(self,
  213. config: CLIPVisionConfig,
  214. quant_config: Optional[QuantizationConfig] = None,
  215. num_hidden_layers_override: Optional[int] = None):
  216. super().__init__()
  217. self.vision_model = CLIPVisionTransformer(
  218. config=config,
  219. quant_config=quant_config,
  220. num_hidden_layers_override=num_hidden_layers_override)
  221. def forward(self, pixel_values: Optional[torch.Tensor] = None):
  222. return self.vision_model(pixel_values=pixel_values)
  223. @property
  224. def device(self):
  225. return next(self.parameters()).device