clip.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. """Minimal implementation of CLIPVisionModel intended to be only used
  2. within a vision language model."""
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. from PIL import Image
  7. from transformers import CLIPVisionConfig
  8. from transformers.models.clip.modeling_clip import CLIPAttention
  9. from aphrodite.common.config import ModelConfig
  10. from aphrodite.common.sequence import SequenceData
  11. from aphrodite.inputs import LLMInputs
  12. from aphrodite.modeling.layers.activation import get_act_fn
  13. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  14. RowParallelLinear)
  15. from aphrodite.multimodal.image import (cached_get_tokenizer,
  16. repeat_and_pad_image_tokens)
  17. from aphrodite.quantization import QuantizationConfig
  18. def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
  19. assert image_size % patch_size == 0
  20. return image_size // patch_size
  21. def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
  22. grid_length = get_clip_patch_grid_length(image_size=image_size,
  23. patch_size=patch_size)
  24. return grid_length * grid_length
  25. def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
  26. return get_clip_num_patches(image_size=hf_config.image_size,
  27. patch_size=hf_config.patch_size)
  28. def dummy_seq_data_for_clip(
  29. hf_config: CLIPVisionConfig,
  30. seq_len: int,
  31. *,
  32. image_token_id: int,
  33. image_feature_size_override: Optional[int] = None,
  34. ):
  35. if image_feature_size_override is None:
  36. image_feature_size = get_clip_image_feature_size(hf_config)
  37. else:
  38. image_feature_size = image_feature_size_override
  39. token_ids = [image_token_id] * image_feature_size
  40. token_ids += [0] * (seq_len - image_feature_size)
  41. return SequenceData(token_ids)
  42. def dummy_image_for_clip(
  43. hf_config: CLIPVisionConfig,
  44. *,
  45. image_width_override: Optional[int] = None,
  46. image_height_override: Optional[int] = None,
  47. ):
  48. width = height = hf_config.image_size
  49. if image_width_override is not None:
  50. width = image_width_override
  51. if image_height_override is not None:
  52. height = image_height_override
  53. image = Image.new("RGB", (width, height), color=0)
  54. return {"image": image}
  55. def input_processor_for_clip(
  56. model_config: ModelConfig,
  57. hf_config: CLIPVisionConfig,
  58. llm_inputs: LLMInputs,
  59. *,
  60. image_token_id: int,
  61. image_feature_size_override: Optional[int] = None,
  62. ):
  63. multi_modal_data = llm_inputs.get("multi_modal_data")
  64. if multi_modal_data is None or "image" not in multi_modal_data:
  65. return llm_inputs
  66. tokenizer = cached_get_tokenizer(model_config.tokenizer)
  67. if image_feature_size_override is None:
  68. image_feature_size = get_clip_image_feature_size(hf_config)
  69. else:
  70. image_feature_size = image_feature_size_override
  71. new_prompt, new_token_ids = repeat_and_pad_image_tokens(
  72. tokenizer,
  73. llm_inputs.get("prompt"),
  74. llm_inputs["prompt_token_ids"],
  75. image_token_id=image_token_id,
  76. repeat_count=image_feature_size,
  77. )
  78. # NOTE: Create a defensive copy of the original inputs
  79. return LLMInputs(prompt_token_ids=new_token_ids,
  80. prompt=new_prompt,
  81. multi_modal_data=multi_modal_data)
  82. # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
  83. class CLIPVisionEmbeddings(nn.Module):
  84. def __init__(self, config: CLIPVisionConfig):
  85. super().__init__()
  86. self.config = config
  87. self.embed_dim = config.hidden_size
  88. self.image_size = config.image_size
  89. self.patch_size = config.patch_size
  90. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  91. self.patch_embedding = nn.Conv2d(
  92. in_channels=config.num_channels,
  93. out_channels=self.embed_dim,
  94. kernel_size=self.patch_size,
  95. stride=self.patch_size,
  96. bias=False,
  97. )
  98. self.num_patches = get_clip_num_patches(image_size=self.image_size,
  99. patch_size=self.patch_size)
  100. self.num_positions = self.num_patches + 1
  101. self.position_embedding = nn.Embedding(self.num_positions,
  102. self.embed_dim)
  103. self.register_buffer("position_ids",
  104. torch.arange(self.num_positions).expand((1, -1)),
  105. persistent=False)
  106. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  107. batch_size = pixel_values.shape[0]
  108. target_dtype = self.patch_embedding.weight.dtype
  109. patch_embeds = self.patch_embedding(pixel_values.to(
  110. dtype=target_dtype)) # shape = [*, width, grid, grid]
  111. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  112. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  113. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  114. embeddings = embeddings + self.position_embedding(self.position_ids)
  115. return embeddings
  116. class CLIPMLP(nn.Module):
  117. def __init__(self,
  118. config: CLIPVisionConfig,
  119. quant_config: Optional[QuantizationConfig] = None):
  120. super().__init__()
  121. self.config = config
  122. self.activation_fn = get_act_fn(config.hidden_act)
  123. self.fc1 = ColumnParallelLinear(config.hidden_size,
  124. config.intermediate_size,
  125. bias=True,
  126. quant_config=quant_config)
  127. self.fc2 = RowParallelLinear(config.intermediate_size,
  128. config.hidden_size,
  129. bias=True,
  130. quant_config=quant_config)
  131. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  132. hidden_states, _ = self.fc1(hidden_states)
  133. hidden_states = self.activation_fn(hidden_states)
  134. hidden_states, _ = self.fc2(hidden_states)
  135. return hidden_states
  136. class CLIPEncoderLayer(nn.Module):
  137. def __init__(self,
  138. config: CLIPVisionConfig,
  139. quant_config: Optional[QuantizationConfig] = None):
  140. super().__init__()
  141. self.self_attn = CLIPAttention(config)
  142. self.layer_norm1 = nn.LayerNorm(config.hidden_size,
  143. eps=config.layer_norm_eps)
  144. self.mlp = CLIPMLP(config, quant_config=quant_config)
  145. self.layer_norm2 = nn.LayerNorm(config.hidden_size,
  146. eps=config.layer_norm_eps)
  147. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  148. residual = hidden_states
  149. hidden_states = self.layer_norm1(hidden_states)
  150. hidden_states, _ = self.self_attn(hidden_states=hidden_states)
  151. hidden_states = residual + hidden_states
  152. residual = hidden_states
  153. hidden_states = self.layer_norm2(hidden_states)
  154. hidden_states = self.mlp(hidden_states)
  155. hidden_states = residual + hidden_states
  156. return hidden_states
  157. class CLIPEncoder(nn.Module):
  158. """
  159. Transformer encoder consisting of `config.num_hidden_layers` self
  160. attention layers. Each layer is a [`CLIPEncoderLayer`].
  161. Args:
  162. config: CLIPConfig
  163. """
  164. def __init__(self,
  165. config: CLIPVisionConfig,
  166. quant_config: Optional[QuantizationConfig] = None):
  167. super().__init__()
  168. self.config = config
  169. self.layers = nn.ModuleList([
  170. CLIPEncoderLayer(config=config, quant_config=quant_config)
  171. for _ in range(config.num_hidden_layers)
  172. ])
  173. def forward(self,
  174. inputs_embeds: torch.Tensor,
  175. vision_feature_layer: int = -1):
  176. # Encoder forward pass only up to the required layer
  177. num_layer = len(self.layers) + vision_feature_layer + 1
  178. hidden_states = inputs_embeds
  179. for encoder_layer in self.layers[:num_layer]:
  180. hidden_states = encoder_layer(hidden_states)
  181. return hidden_states
  182. class CLIPVisionTransformer(nn.Module):
  183. def __init__(self,
  184. config: CLIPVisionConfig,
  185. quant_config: Optional[QuantizationConfig] = None):
  186. super().__init__()
  187. self.config = config
  188. embed_dim = config.hidden_size
  189. self.embeddings = CLIPVisionEmbeddings(config)
  190. # NOTE: This typo of "layrnorm" is not fixed on purpose to match
  191. # the original transformers code and name of the model weights.
  192. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  193. self.encoder = CLIPEncoder(config=config, quant_config=quant_config)
  194. def forward(
  195. self,
  196. pixel_values: torch.Tensor,
  197. vision_feature_layer: int = -1,
  198. ) -> torch.Tensor:
  199. hidden_states = self.embeddings(pixel_values)
  200. hidden_states = self.pre_layrnorm(hidden_states)
  201. hidden_states = self.encoder(inputs_embeds=hidden_states,
  202. vision_feature_layer=vision_feature_layer)
  203. return hidden_states
  204. class CLIPVisionModel(nn.Module):
  205. config_class = CLIPVisionConfig
  206. main_input_name = "pixel_values"
  207. def __init__(self,
  208. config: CLIPVisionConfig,
  209. quant_config: Optional[QuantizationConfig] = None):
  210. super().__init__()
  211. self.vision_model = CLIPVisionTransformer(config=config,
  212. quant_config=quant_config)
  213. def forward(self,
  214. pixel_values: Optional[torch.Tensor] = None,
  215. vision_feature_layer: int = -1):
  216. return self.vision_model(pixel_values=pixel_values,
  217. vision_feature_layer=vision_feature_layer)
  218. @property
  219. def device(self):
  220. return next(self.parameters()).device