  1. """Minimal implementation of CLIPVisionModel intended to be only used
  2. within a vision language model."""
  3. from typing import Iterable, List, Optional, Tuple, Union
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. from PIL import Image
  8. from transformers import CLIPVisionConfig
  9. from transformers.models.clip.modeling_clip import CLIPSdpaAttention
  10. from aphrodite.common.config import ModelConfig
  11. from aphrodite.common.sequence import SequenceData
  12. from aphrodite.distributed import divide, get_tensor_model_parallel_world_size
  13. from aphrodite.inputs import LLMInputs
  14. from aphrodite.modeling.layers.activation import get_act_fn
  15. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  16. QKVParallelLinear,
  17. RowParallelLinear)
  18. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  19. from aphrodite.multimodal.utils import (cached_get_tokenizer,
  20. repeat_and_pad_placeholder_tokens)
  21. from aphrodite.quantization import QuantizationConfig
  22. try:
  23. from xformers import ops as xops
  24. USE_XFORMERS_OPS = True
  25. except ImportError:
  26. USE_XFORMERS_OPS = False
  27. def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
  28. assert image_size % patch_size == 0
  29. return image_size // patch_size
  30. def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
  31. grid_length = get_clip_patch_grid_length(image_size=image_size,
  32. patch_size=patch_size)
  33. return grid_length * grid_length
  34. def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
  35. return get_clip_num_patches(image_size=hf_config.image_size,
  36. patch_size=hf_config.patch_size) + 1
  37. def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
  38. return get_clip_image_feature_size(hf_config)
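

# Worked example of the sizes computed above (a sketch, assuming the common
# CLIP ViT-L/14-336 configuration, i.e. image_size=336 and patch_size=14;
# the numbers are illustrative and not read from any config in this file):
#   grid_length        = 336 // 14      = 24
#   num_patches        = 24 * 24        = 576
#   image_feature_size = 576 + 1 (CLS)  = 577
# so each image occupies 577 placeholder positions in the prompt.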
def dummy_seq_data_for_clip(
    hf_config: CLIPVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_clip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    return SequenceData.from_token_counts(
        (image_token_id, image_feature_size * num_images),
        (0, seq_len - image_feature_size * num_images),
    )


def dummy_image_for_clip(
    hf_config: CLIPVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_video_for_clip(
    hf_config: CLIPVisionConfig,
    num_frames: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    pil_frame = dummy_image_for_clip(
        hf_config,
        num_images=1,
        image_width_override=image_width_override,
        image_height_override=image_height_override)
    np_frame = np.array(pil_frame["image"])
    mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
    mm_data = {"video": mm_data_per_video}
    return mm_data


def input_processor_for_clip(
    model_config: ModelConfig,
    hf_config: CLIPVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_clip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            num_images, image_feature_size, hidden_size = image_data.shape
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
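

# Example of the expansion performed above (a sketch; image_token_id=32000 is
# a LLaVA-style placeholder id used purely for illustration): with
# image_feature_size=577, prompt token ids
#   [bos, 32000, t1, t2]
# become
#   [bos] + [32000] * 577 + [t1, t2]
# i.e. one placeholder position per patch embedding (+ CLS) produced by the
# vision tower.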
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = get_clip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        self.register_buffer("position_ids",
                             torch.arange(self.num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings


class CLIPParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        bsz, tgt_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        query_states = query_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(bsz, tgt_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=self.dropout,
                                                      scale=self.scale)
        out = out.view(bsz, tgt_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
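

# Sharding sketch for CLIPParallelAttention (illustrative numbers only,
# assuming a CLIP ViT-L vision tower: hidden_size=1024,
# num_attention_heads=16, head_dim=64): with tp_size=2, each rank keeps
# num_heads_per_partition = 16 / 2 = 8 heads, so qkv_proj emits
# 3 * 8 * 64 = 1536 features per rank and out_proj (a RowParallelLinear)
# reduces the partial outputs across ranks back to the full 1024-dim
# hidden size.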
class CLIPMLP(nn.Module):

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config)
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states


class CLIPEncoderLayer(nn.Module):

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
            self.self_attn = CLIPParallelAttention(config,
                                                   quant_config=quant_config)
        else:
            self.self_attn = CLIPSdpaAttention(config)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config, quant_config=quant_config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            CLIPEncoderLayer(config=config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        return hidden_states
class CLIPVisionTransformer(nn.Module):

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # The final post_layernorm is not used when the vision tower only
        # feeds a language model, so it is kept as None; `load_weights`
        # checks this attribute to skip the corresponding checkpoint weights.
        self.post_layernorm = None

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)
        hidden_states = self.encoder(inputs_embeds=hidden_states)

        return hidden_states
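

# Shape note (follows from CLIPVisionEmbeddings, not enforced here): the
# transformer returns hidden states of shape
# [batch_size, 1 + num_patches, hidden_size], where the leading position is
# the CLS token embedding.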
class CLIPVisionModel(nn.Module):

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()

        tp_size = get_tensor_model_parallel_world_size()
        num_heads = config.num_attention_heads
        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override)

    def forward(self, pixel_values: Optional[torch.Tensor] = None):
        return self.vision_model(pixel_values=pixel_values)

    @property
    def device(self):
        return next(self.parameters()).device

    # TODO: Add prefix argument for filtering out weights to be loaded
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ] if self.shard_weight else []
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if (name.startswith("vision_model.post_layernorm")
                    and self.vision_model.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
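

# Minimal usage sketch (illustrative only; constructing the model requires
# Aphrodite's tensor-parallel state to be initialized, and `hf_state_dict`
# is a hypothetical name -> tensor mapping from a HuggingFace checkpoint, so
# this is not meant to run standalone):
#
#   config = CLIPVisionConfig.from_pretrained(
#       "openai/clip-vit-large-patch14-336")
#   vision_tower = CLIPVisionModel(config)
#   vision_tower.load_weights(hf_state_dict.items())
#   pixel_values = torch.zeros(1, 3, 336, 336)
#   features = vision_tower(pixel_values)  # [1, 577, config.hidden_size]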