# intern_vit.py (11 KB)
  1. # adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
  2. # --------------------------------------------------------
  3. # InternVL
  4. # Copyright (c) 2023 OpenGVLab
  5. # Licensed under The MIT License [see LICENSE for details]
  6. # --------------------------------------------------------
  7. from typing import Iterable, Optional, Tuple
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from transformers import PretrainedConfig
  12. from aphrodite.modeling.layers.activation import get_act_fn
  13. from aphrodite.modeling.layers.layernorm import RMSNorm
  14. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  15. RowParallelLinear)
  16. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  17. from aphrodite.quantization import QuantizationConfig
  18. NORM2FN = {
  19. 'rms_norm': RMSNorm,
  20. 'layer_norm': nn.LayerNorm,
  21. }
  22. class InternVisionEmbeddings(nn.Module):
  23. def __init__(self, config: PretrainedConfig):
  24. super().__init__()
  25. self.config = config
  26. self.embed_dim = config.hidden_size
  27. self.image_size = config.image_size
  28. self.patch_size = config.patch_size
  29. self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
  30. self.patch_embedding = nn.Conv2d(in_channels=3,
  31. out_channels=self.embed_dim,
  32. kernel_size=self.patch_size,
  33. stride=self.patch_size)
  34. self.num_patches = (self.image_size // self.patch_size)**2
  35. self.num_positions = self.num_patches + 1
  36. self.position_embedding = nn.Parameter(
  37. torch.randn(1, self.num_positions, self.embed_dim))
  38. def _get_pos_embed(self, pos_embed, H, W):
  39. target_dtype = pos_embed.dtype
  40. pos_embed = pos_embed.float().reshape(
  41. 1, self.image_size // self.patch_size,
  42. self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
  43. pos_embed = F.interpolate(pos_embed,
  44. size=(H, W),
  45. mode='bicubic',
  46. align_corners=False)
  47. pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2,
  48. 1).to(target_dtype)
  49. return pos_embed
  50. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  51. target_dtype = self.patch_embedding.weight.dtype
  52. patch_embeds = self.patch_embedding(pixel_values.to(
  53. target_dtype)) # shape = [*, channel, width, height]
  54. batch_size, _, height, width = patch_embeds.shape
  55. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  56. class_embeds = self.class_embedding.expand(batch_size, 1,
  57. -1).to(target_dtype)
  58. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  59. position_embedding = torch.cat([
  60. self.position_embedding[:, :1, :],
  61. self._get_pos_embed(self.position_embedding[:, 1:, :], height,
  62. width)
  63. ],
  64. dim=1)
  65. embeddings = embeddings + position_embedding.to(target_dtype)
  66. return embeddings
  67. class InternAttention(nn.Module):
  68. """Multi-headed attention from 'Attention Is All You Need' paper"""
  69. def __init__(self, config: PretrainedConfig):
  70. super().__init__()
  71. self.config = config
  72. self.embed_dim = config.hidden_size
  73. self.num_heads = config.num_attention_heads
  74. self.head_dim = self.embed_dim // self.num_heads
  75. if self.head_dim * self.num_heads != self.embed_dim:
  76. raise ValueError(
  77. f'embed_dim must be divisible by num_heads '
  78. f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
  79. f' {self.num_heads}).')
  80. self.scale = self.head_dim**-0.5
  81. self.qkv = nn.Linear(self.embed_dim,
  82. 3 * self.embed_dim,
  83. bias=config.qkv_bias)
  84. self.qk_normalization = config.qk_normalization
  85. if self.qk_normalization:
  86. self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
  87. self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
  88. self.proj = nn.Linear(self.embed_dim, self.embed_dim)
  89. def forward(self, x):
  90. B, N, C = x.shape
  91. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
  92. C // self.num_heads).permute(2, 0, 3, 1, 4)
  93. q, k, v = qkv.unbind(0)
  94. if self.qk_normalization:
  95. B_, H_, N_, D_ = q.shape
  96. q = self.q_norm.forward_native(q.transpose(1, 2).flatten(
  97. -2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
  98. k = self.k_norm.forward_native(k.transpose(1, 2).flatten(
  99. -2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
  100. x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
  101. x = x.transpose(1, 2).reshape(B, N, C)
  102. x = self.proj(x)
  103. return x
  104. class InternMLP(nn.Module):
  105. def __init__(self,
  106. config: PretrainedConfig,
  107. quant_config: Optional[QuantizationConfig] = None):
  108. super().__init__()
  109. self.config = config
  110. self.activation_fn = get_act_fn(config.hidden_act)
  111. self.fc1 = ColumnParallelLinear(config.hidden_size,
  112. config.intermediate_size,
  113. bias=True,
  114. quant_config=quant_config)
  115. self.fc2 = RowParallelLinear(config.intermediate_size,
  116. config.hidden_size,
  117. bias=True,
  118. quant_config=quant_config)
  119. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  120. hidden_states, _ = self.fc1(hidden_states)
  121. hidden_states = self.activation_fn(hidden_states)
  122. hidden_states, _ = self.fc2(hidden_states)
  123. return hidden_states
  124. class InternVisionEncoderLayer(nn.Module):
  125. def __init__(self,
  126. config: PretrainedConfig,
  127. quant_config: Optional[QuantizationConfig] = None):
  128. super().__init__()
  129. self.embed_dim = config.hidden_size
  130. self.intermediate_size = config.intermediate_size
  131. self.norm_type = config.norm_type
  132. self.attn = InternAttention(config)
  133. self.mlp = InternMLP(config, quant_config=quant_config)
  134. self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
  135. eps=config.layer_norm_eps)
  136. self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
  137. eps=config.layer_norm_eps)
  138. self.ls1 = nn.Parameter(config.initializer_factor *
  139. torch.ones(self.embed_dim))
  140. self.ls2 = nn.Parameter(config.initializer_factor *
  141. torch.ones(self.embed_dim))
  142. def forward(
  143. self,
  144. hidden_states: torch.Tensor,
  145. ):
  146. hidden_states = hidden_states + self.attn(
  147. self.norm1(hidden_states)) * self.ls1
  148. hidden_states = hidden_states + self.mlp(
  149. self.norm2(hidden_states)) * self.ls2
  150. return hidden_states
  151. class InternVisionEncoder(nn.Module):
  152. def __init__(self,
  153. config: PretrainedConfig,
  154. quant_config: Optional[QuantizationConfig] = None,
  155. num_hidden_layers_override: Optional[int] = None):
  156. super().__init__()
  157. self.config = config
  158. if num_hidden_layers_override is None:
  159. num_hidden_layers = config.num_hidden_layers
  160. else:
  161. num_hidden_layers = num_hidden_layers_override
  162. self.layers = nn.ModuleList([
  163. InternVisionEncoderLayer(config=config, quant_config=quant_config)
  164. for _ in range(num_hidden_layers)
  165. ])
  166. def forward(self, inputs_embeds: torch.Tensor):
  167. hidden_states = inputs_embeds
  168. for encoder_layer in self.layers:
  169. hidden_states = encoder_layer(hidden_states)
  170. return hidden_states
  171. class InternVisionModel(nn.Module):
  172. def __init__(self,
  173. config: PretrainedConfig,
  174. quant_config: Optional[QuantizationConfig] = None,
  175. num_hidden_layers_override: Optional[int] = None):
  176. super().__init__()
  177. self.config = config
  178. self.embeddings = InternVisionEmbeddings(config)
  179. self.encoder = InternVisionEncoder(
  180. config=config,
  181. quant_config=quant_config,
  182. num_hidden_layers_override=num_hidden_layers_override)
  183. def resize_pos_embeddings(self, old_size, new_size, patch_size):
  184. pos_emb = self.embeddings.position_embedding
  185. _, num_positions, embed_dim = pos_emb.shape
  186. cls_emb = pos_emb[:, :1, :]
  187. pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size,
  188. old_size // patch_size,
  189. -1).permute(0, 3, 1, 2)
  190. pos_emb = F.interpolate(pos_emb.float(),
  191. size=new_size // patch_size,
  192. mode='bicubic',
  193. align_corners=False)
  194. pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim,
  195. -1).permute(0, 2, 1)
  196. pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
  197. self.embeddings.position_embedding = nn.Parameter(pos_emb)
  198. self.embeddings.image_size = new_size
  199. def get_input_embeddings(self):
  200. return self.embeddings
  201. def forward(
  202. self,
  203. pixel_values: Optional[torch.Tensor] = None,
  204. pixel_embeds: Optional[torch.Tensor] = None,
  205. ) -> torch.FloatTensor:
  206. if pixel_values is None and pixel_embeds is None:
  207. raise ValueError(
  208. 'You have to specify pixel_values or pixel_embeds')
  209. if pixel_embeds is not None:
  210. hidden_states = pixel_embeds
  211. elif pixel_values is not None:
  212. if pixel_values.ndim == 4:
  213. hidden_states = self.embeddings(pixel_values)
  214. else:
  215. raise ValueError(
  216. f'wrong pixel_values size: {pixel_values.shape}')
  217. encoder_outputs = self.encoder(inputs_embeds=hidden_states)
  218. return encoder_outputs
  219. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  220. params_dict = dict(self.named_parameters())
  221. for name, loaded_weight in weights:
  222. param = params_dict[name]
  223. weight_loader = getattr(param, "weight_loader",
  224. default_weight_loader)
  225. weight_loader(param, loaded_weight)