# intern_vit.py
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
  7. from typing import Iterable, Optional, Tuple
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from transformers import PretrainedConfig
  12. from aphrodite.distributed import divide, get_tensor_model_parallel_world_size
  13. from aphrodite.modeling.layers.activation import get_act_fn
  14. from aphrodite.modeling.layers.layernorm import RMSNorm
  15. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  16. QKVParallelLinear,
  17. RowParallelLinear)
  18. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  19. from aphrodite.quantization import QuantizationConfig
  20. try:
  21. from xformers import ops as xops
  22. USE_XFORMERS_OPS = True
  23. except ImportError:
  24. USE_XFORMERS_OPS = False
  25. NORM2FN = {
  26. 'rms_norm': RMSNorm,
  27. 'layer_norm': nn.LayerNorm,
  28. }
  29. class InternVisionEmbeddings(nn.Module):
  30. def __init__(self, config: PretrainedConfig):
  31. super().__init__()
  32. self.config = config
  33. self.embed_dim = config.hidden_size
  34. self.image_size = config.image_size
  35. self.patch_size = config.patch_size
  36. self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
  37. self.patch_embedding = nn.Conv2d(in_channels=3,
  38. out_channels=self.embed_dim,
  39. kernel_size=self.patch_size,
  40. stride=self.patch_size)
  41. self.num_patches = (self.image_size // self.patch_size)**2
  42. self.num_positions = self.num_patches + 1
  43. self.position_embedding = nn.Parameter(
  44. torch.randn(1, self.num_positions, self.embed_dim))
  45. def _get_pos_embed(self, pos_embed, H, W):
  46. target_dtype = pos_embed.dtype
  47. pos_embed = pos_embed.float().reshape(
  48. 1, self.image_size // self.patch_size,
  49. self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
  50. pos_embed = F.interpolate(pos_embed,
  51. size=(H, W),
  52. mode='bicubic',
  53. align_corners=False)
  54. pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2,
  55. 1).to(target_dtype)
  56. return pos_embed
  57. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  58. target_dtype = self.patch_embedding.weight.dtype
  59. patch_embeds = self.patch_embedding(pixel_values.to(
  60. target_dtype)) # shape = [*, channel, width, height]
  61. batch_size, _, height, width = patch_embeds.shape
  62. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  63. class_embeds = self.class_embedding.expand(batch_size, 1,
  64. -1).to(target_dtype)
  65. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  66. position_embedding = torch.cat([
  67. self.position_embedding[:, :1, :],
  68. self._get_pos_embed(self.position_embedding[:, 1:, :], height,
  69. width)
  70. ],
  71. dim=1)
  72. embeddings = embeddings + position_embedding.to(target_dtype)
  73. return embeddings
  74. class InternParallelAttention(nn.Module):
  75. """Multi-headed attention from 'Attention Is All You Need' paper"""
  76. def __init__(
  77. self,
  78. config: PretrainedConfig,
  79. quant_config: Optional[QuantizationConfig] = None,
  80. ):
  81. super().__init__()
  82. self.config = config
  83. self.embed_dim = config.hidden_size
  84. self.num_heads = config.num_attention_heads
  85. self.head_dim = self.embed_dim // self.num_heads
  86. if self.head_dim * self.num_heads != self.embed_dim:
  87. raise ValueError(
  88. f'embed_dim must be divisible by num_heads '
  89. f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
  90. f' {self.num_heads}).')
  91. self.scale = self.head_dim**-0.5
  92. self.qkv = QKVParallelLinear(
  93. self.embed_dim,
  94. self.head_dim,
  95. self.num_heads,
  96. bias=config.qkv_bias,
  97. quant_config=quant_config,
  98. )
  99. self.qk_normalization = config.qk_normalization
  100. if self.qk_normalization:
  101. self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
  102. self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
  103. self.proj = RowParallelLinear(
  104. self.embed_dim,
  105. self.embed_dim,
  106. quant_config=quant_config,
  107. )
  108. self.tp_size = get_tensor_model_parallel_world_size()
  109. self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
  110. def forward(self, x):
  111. B, N, C = x.shape
  112. qkv, _ = self.qkv(x)
  113. q, k, v = qkv.chunk(3, dim=-1)
  114. q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
  115. k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
  116. v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
  117. if self.qk_normalization:
  118. B_, N_, H_, D_ = q.shape
  119. q = self.q_norm.forward_native(q.flatten(-2,
  120. -1)).view(B_, N_, H_, D_)
  121. k = self.k_norm.forward_native(k.flatten(-2,
  122. -1)).view(B_, N_, H_, D_)
  123. x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
  124. x = x.view(B, N, -1)
  125. x, _ = self.proj(x)
  126. return x
  127. class InternSdpaAttention(nn.Module):
  128. """Multi-headed attention from 'Attention Is All You Need' paper"""
  129. def __init__(self, config: PretrainedConfig):
  130. super().__init__()
  131. self.config = config
  132. self.embed_dim = config.hidden_size
  133. self.num_heads = config.num_attention_heads
  134. self.head_dim = self.embed_dim // self.num_heads
  135. if self.head_dim * self.num_heads != self.embed_dim:
  136. raise ValueError(
  137. f'embed_dim must be divisible by num_heads '
  138. f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
  139. f' {self.num_heads}).')
  140. self.scale = self.head_dim**-0.5
  141. self.qkv = nn.Linear(self.embed_dim,
  142. 3 * self.embed_dim,
  143. bias=config.qkv_bias)
  144. self.qk_normalization = config.qk_normalization
  145. if self.qk_normalization:
  146. self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
  147. self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
  148. self.proj = nn.Linear(self.embed_dim, self.embed_dim)
  149. def forward(self, x):
  150. B, N, C = x.shape
  151. qkv = self.qkv(x)
  152. q, k, v = qkv.chunk(3, dim=-1)
  153. q = q.view(B, N, self.num_heads, self.head_dim)
  154. k = k.view(B, N, self.num_heads, self.head_dim)
  155. v = v.view(B, N, self.num_heads, self.head_dim)
  156. if self.qk_normalization:
  157. B_, N_, H_, D_ = q.shape
  158. q = self.q_norm.forward_native(q.flatten(-2,
  159. -1)).view(B_, N_, H_, D_)
  160. k = self.k_norm.forward_native(k.flatten(-2,
  161. -1)).view(B_, N_, H_, D_)
  162. q = q.transpose(1, 2)
  163. k = k.transpose(1, 2)
  164. v = v.transpose(1, 2)
  165. x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
  166. x = x.transpose(1, 2).view(B, N, -1)
  167. x = self.proj(x)
  168. return x
  169. class InternMLP(nn.Module):
  170. def __init__(self,
  171. config: PretrainedConfig,
  172. quant_config: Optional[QuantizationConfig] = None):
  173. super().__init__()
  174. self.config = config
  175. self.activation_fn = get_act_fn(config.hidden_act)
  176. self.fc1 = ColumnParallelLinear(config.hidden_size,
  177. config.intermediate_size,
  178. bias=True,
  179. quant_config=quant_config)
  180. self.fc2 = RowParallelLinear(config.intermediate_size,
  181. config.hidden_size,
  182. bias=True,
  183. quant_config=quant_config)
  184. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  185. hidden_states, _ = self.fc1(hidden_states)
  186. hidden_states = self.activation_fn(hidden_states)
  187. hidden_states, _ = self.fc2(hidden_states)
  188. return hidden_states
  189. class InternVisionEncoderLayer(nn.Module):
  190. def __init__(self,
  191. config: PretrainedConfig,
  192. quant_config: Optional[QuantizationConfig] = None):
  193. super().__init__()
  194. self.embed_dim = config.hidden_size
  195. self.intermediate_size = config.intermediate_size
  196. self.norm_type = config.norm_type
  197. # fallback to sdpa attention if tp unavailable
  198. tp_size = get_tensor_model_parallel_world_size()
  199. num_heads = config.num_attention_heads
  200. if USE_XFORMERS_OPS and num_heads % tp_size == 0:
  201. self.attn = InternParallelAttention(config,
  202. quant_config=quant_config)
  203. else:
  204. self.attn = InternSdpaAttention(config)
  205. self.mlp = InternMLP(config, quant_config=quant_config)
  206. self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
  207. eps=config.layer_norm_eps)
  208. self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
  209. eps=config.layer_norm_eps)
  210. self.ls1 = nn.Parameter(config.initializer_factor *
  211. torch.ones(self.embed_dim))
  212. self.ls2 = nn.Parameter(config.initializer_factor *
  213. torch.ones(self.embed_dim))
  214. def forward(
  215. self,
  216. hidden_states: torch.Tensor,
  217. ):
  218. hidden_states = hidden_states + self.attn(
  219. self.norm1(hidden_states)) * self.ls1
  220. hidden_states = hidden_states + self.mlp(
  221. self.norm2(hidden_states)) * self.ls2
  222. return hidden_states
  223. class InternVisionEncoder(nn.Module):
  224. def __init__(self,
  225. config: PretrainedConfig,
  226. quant_config: Optional[QuantizationConfig] = None,
  227. num_hidden_layers_override: Optional[int] = None):
  228. super().__init__()
  229. self.config = config
  230. if num_hidden_layers_override is None:
  231. num_hidden_layers = config.num_hidden_layers
  232. else:
  233. num_hidden_layers = num_hidden_layers_override
  234. self.layers = nn.ModuleList([
  235. InternVisionEncoderLayer(config=config, quant_config=quant_config)
  236. for _ in range(num_hidden_layers)
  237. ])
  238. def forward(self, inputs_embeds: torch.Tensor):
  239. hidden_states = inputs_embeds
  240. for encoder_layer in self.layers:
  241. hidden_states = encoder_layer(hidden_states)
  242. return hidden_states
  243. class InternVisionModel(nn.Module):
  244. def __init__(self,
  245. config: PretrainedConfig,
  246. quant_config: Optional[QuantizationConfig] = None,
  247. num_hidden_layers_override: Optional[int] = None):
  248. super().__init__()
  249. self.config = config
  250. self.embeddings = InternVisionEmbeddings(config)
  251. self.encoder = InternVisionEncoder(
  252. config=config,
  253. quant_config=quant_config,
  254. num_hidden_layers_override=num_hidden_layers_override)
  255. def resize_pos_embeddings(self, old_size, new_size, patch_size):
  256. pos_emb = self.embeddings.position_embedding
  257. _, num_positions, embed_dim = pos_emb.shape
  258. cls_emb = pos_emb[:, :1, :]
  259. pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size,
  260. old_size // patch_size,
  261. -1).permute(0, 3, 1, 2)
  262. pos_emb = F.interpolate(pos_emb.float(),
  263. size=new_size // patch_size,
  264. mode='bicubic',
  265. align_corners=False)
  266. pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim,
  267. -1).permute(0, 2, 1)
  268. pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
  269. self.embeddings.position_embedding = nn.Parameter(pos_emb)
  270. self.embeddings.image_size = new_size
  271. def get_input_embeddings(self):
  272. return self.embeddings
  273. def forward(
  274. self,
  275. pixel_values: Optional[torch.Tensor] = None,
  276. pixel_embeds: Optional[torch.Tensor] = None,
  277. ) -> torch.FloatTensor:
  278. if pixel_values is None and pixel_embeds is None:
  279. raise ValueError(
  280. 'You have to specify pixel_values or pixel_embeds')
  281. if pixel_embeds is not None:
  282. hidden_states = pixel_embeds
  283. elif pixel_values is not None:
  284. if pixel_values.ndim == 4:
  285. hidden_states = self.embeddings(pixel_values)
  286. else:
  287. raise ValueError(
  288. f'wrong pixel_values size: {pixel_values.shape}')
  289. encoder_outputs = self.encoder(inputs_embeds=hidden_states)
  290. return encoder_outputs
  291. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  292. params_dict = dict(self.named_parameters())
  293. for name, loaded_weight in weights:
  294. param = params_dict[name]
  295. weight_loader = getattr(param, "weight_loader",
  296. default_weight_loader)
  297. weight_loader(param, loaded_weight)