"""Minimal implementation of CLIPVisionModel intended to be used only
within a vision language model."""
from array import array
from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPSdpaAttention

from aphrodite.common.config import ModelConfig
from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                       SequenceData)
from aphrodite.distributed import divide, get_tensor_model_parallel_world_size
from aphrodite.inputs import LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.multimodal.utils import (cached_get_tokenizer,
                                        repeat_and_pad_placeholder_tokens)
from aphrodite.quantization import QuantizationConfig

try:
    from xformers import ops as xops
    USE_XFORMERS_OPS = True
except ImportError:
    USE_XFORMERS_OPS = False


def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    assert image_size % patch_size == 0
    return image_size // patch_size


def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
    grid_length = get_clip_patch_grid_length(image_size=image_size,
                                             patch_size=patch_size)
    return grid_length * grid_length


def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
    return get_clip_num_patches(image_size=hf_config.image_size,
                                patch_size=hf_config.patch_size) + 1


def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
    return get_clip_image_feature_size(hf_config)
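

# Worked example (illustrative): a ViT-L/14 vision tower at 336x336 resolution
# has a 336 // 14 = 24 patch grid, i.e. 24 * 24 = 576 patches, so each image
# contributes 576 + 1 = 577 feature tokens (patches plus the class embedding).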


def dummy_seq_data_for_clip(
    hf_config: CLIPVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_clip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * image_feature_size * num_images
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_clip(
    hf_config: CLIPVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}
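

# NOTE: The input processor below expands the image placeholder token in the
# prompt into `image_feature_size` repeated copies (via
# repeat_and_pad_placeholder_tokens), so the language model reserves one token
# position for every visual embedding produced by the vision tower.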


def input_processor_for_clip(
    model_config: ModelConfig,
    hf_config: CLIPVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_clip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            num_images, image_feature_size, hidden_size = image_data.shape
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = get_clip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        self.register_buffer("position_ids",
                             torch.arange(self.num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings
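

# NOTE: The attention module below shards heads across tensor-parallel ranks:
# the fused QKV projection is column-parallel, the output projection is
# row-parallel, and each rank runs xformers memory-efficient attention over
# its own slice of the heads. No causal mask is applied, as expected for a
# vision encoder.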


class CLIPParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        bsz, tgt_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        query_states = query_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(bsz, tgt_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=self.dropout,
                                                      scale=self.scale)
        out = out.view(bsz, tgt_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output, None


class CLIPMLP(nn.Module):

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config)
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states


class CLIPEncoderLayer(nn.Module):

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Use the tensor-parallel attention only when xformers is available
        # and the head count divides evenly across ranks; otherwise fall back
        # to the unsharded SDPA attention from transformers.
        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
            self.self_attn = CLIPParallelAttention(config,
                                                   quant_config=quant_config)
        else:
            self.self_attn = CLIPSdpaAttention(config)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config, quant_config=quant_config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`CLIPEncoderLayer`].

    Args:
        config: CLIPVisionConfig
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            CLIPEncoderLayer(config=config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        return hidden_states


class CLIPVisionTransformer(nn.Module):

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)
        hidden_states = self.encoder(inputs_embeds=hidden_states)

        return hidden_states
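

# NOTE: Unlike transformers' CLIPVisionModel, the wrapper below returns the
# raw last hidden states (one vector per patch plus the class token) and has
# no post_layernorm or pooling head; the consuming vision language model picks
# the features it needs, and load_weights() skips the unused post_layernorm
# weights accordingly.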


class CLIPVisionModel(nn.Module):

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        num_heads = config.num_attention_heads
        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override)

    def forward(self, pixel_values: Optional[torch.Tensor] = None):
        return self.vision_model(pixel_values=pixel_values)

    @property
    def device(self):
        return next(self.parameters()).device

    # TODO: Add prefix argument for filtering out weights to be loaded
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ] if self.shard_weight else []
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            # omit layers when num_hidden_layers_override is set
            if "vision_model.encoder.layers." in name:
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue

                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
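

# Usage sketch (illustrative only; not part of the original module). Assuming
# the engine has already initialized Aphrodite's distributed/tensor-parallel
# state, embedding a batch of images looks roughly like:
#
#     config = CLIPVisionConfig()  # defaults: 224x224 images, 32x32 patches
#     vision_tower = CLIPVisionModel(config)
#     pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
#     features = vision_tower(pixel_values)
#     # features has shape (1, get_clip_image_feature_size(config), 768),
#     # i.e. (1, 50, 768) for the default config.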