intern_vit.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. # adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
  2. # --------------------------------------------------------
  3. # InternVL
  4. # Copyright (c) 2023 OpenGVLab
  5. # Licensed under The MIT License [see LICENSE for details]
  6. # --------------------------------------------------------
  7. from typing import Iterable, Optional, Tuple
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from transformers import PretrainedConfig
  12. from xformers import ops as xops
  13. from aphrodite.distributed import divide, get_tensor_model_parallel_world_size
  14. from aphrodite.modeling.layers.activation import get_act_fn
  15. from aphrodite.modeling.layers.layernorm import RMSNorm
  16. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  17. QKVParallelLinear,
  18. RowParallelLinear)
  19. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  20. from aphrodite.quantization import QuantizationConfig
  21. NORM2FN = {
  22. 'rms_norm': RMSNorm,
  23. 'layer_norm': nn.LayerNorm,
  24. }
  25. class InternVisionEmbeddings(nn.Module):
  26. def __init__(self, config: PretrainedConfig):
  27. super().__init__()
  28. self.config = config
  29. self.embed_dim = config.hidden_size
  30. self.image_size = config.image_size
  31. self.patch_size = config.patch_size
  32. self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
  33. self.patch_embedding = nn.Conv2d(in_channels=3,
  34. out_channels=self.embed_dim,
  35. kernel_size=self.patch_size,
  36. stride=self.patch_size)
  37. self.num_patches = (self.image_size // self.patch_size)**2
  38. self.num_positions = self.num_patches + 1
  39. self.position_embedding = nn.Parameter(
  40. torch.randn(1, self.num_positions, self.embed_dim))
  41. def _get_pos_embed(self, pos_embed, H, W):
  42. target_dtype = pos_embed.dtype
  43. pos_embed = pos_embed.float().reshape(
  44. 1, self.image_size // self.patch_size,
  45. self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
  46. pos_embed = F.interpolate(pos_embed,
  47. size=(H, W),
  48. mode='bicubic',
  49. align_corners=False)
  50. pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2,
  51. 1).to(target_dtype)
  52. return pos_embed
  53. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  54. target_dtype = self.patch_embedding.weight.dtype
  55. patch_embeds = self.patch_embedding(pixel_values.to(
  56. target_dtype)) # shape = [*, channel, width, height]
  57. batch_size, _, height, width = patch_embeds.shape
  58. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  59. class_embeds = self.class_embedding.expand(batch_size, 1,
  60. -1).to(target_dtype)
  61. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  62. position_embedding = torch.cat([
  63. self.position_embedding[:, :1, :],
  64. self._get_pos_embed(self.position_embedding[:, 1:, :], height,
  65. width)
  66. ],
  67. dim=1)
  68. embeddings = embeddings + position_embedding.to(target_dtype)
  69. return embeddings
  70. class InternAttention(nn.Module):
  71. """Multi-headed attention from 'Attention Is All You Need' paper"""
  72. def __init__(
  73. self,
  74. config: PretrainedConfig,
  75. quant_config: Optional[QuantizationConfig] = None,
  76. ):
  77. super().__init__()
  78. self.config = config
  79. self.embed_dim = config.hidden_size
  80. self.num_heads = config.num_attention_heads
  81. self.head_dim = self.embed_dim // self.num_heads
  82. if self.head_dim * self.num_heads != self.embed_dim:
  83. raise ValueError(
  84. f'embed_dim must be divisible by num_heads '
  85. f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
  86. f' {self.num_heads}).')
  87. self.scale = self.head_dim**-0.5
  88. self.qkv = QKVParallelLinear(
  89. self.embed_dim,
  90. self.head_dim,
  91. self.num_heads,
  92. bias=config.qkv_bias,
  93. quant_config=quant_config,
  94. )
  95. self.qk_normalization = config.qk_normalization
  96. if self.qk_normalization:
  97. self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
  98. self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
  99. self.proj = RowParallelLinear(
  100. self.embed_dim,
  101. self.embed_dim,
  102. quant_config=quant_config,
  103. )
  104. self.tp_size = get_tensor_model_parallel_world_size()
  105. self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
  106. def forward(self, x):
  107. B, N, C = x.shape
  108. qkv, _ = self.qkv(x)
  109. q, k, v = qkv.chunk(3, dim=-1)
  110. q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
  111. k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
  112. v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
  113. if self.qk_normalization:
  114. B_, N_, H_, D_ = q.shape
  115. q = self.q_norm.forward_native(q.flatten(-2,
  116. -1)).view(B_, N_, H_, D_)
  117. k = self.k_norm.forward_native(k.flatten(-2,
  118. -1)).view(B_, N_, H_, D_)
  119. x = xops.memory_efficient_attention_forward(
  120. q,
  121. k,
  122. v,
  123. scale=self.scale,
  124. )
  125. x = x.view(B, N, -1)
  126. x, _ = self.proj(x)
  127. return x
  128. class InternMLP(nn.Module):
  129. def __init__(self,
  130. config: PretrainedConfig,
  131. quant_config: Optional[QuantizationConfig] = None):
  132. super().__init__()
  133. self.config = config
  134. self.activation_fn = get_act_fn(config.hidden_act)
  135. self.fc1 = ColumnParallelLinear(config.hidden_size,
  136. config.intermediate_size,
  137. bias=True,
  138. quant_config=quant_config)
  139. self.fc2 = RowParallelLinear(config.intermediate_size,
  140. config.hidden_size,
  141. bias=True,
  142. quant_config=quant_config)
  143. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  144. hidden_states, _ = self.fc1(hidden_states)
  145. hidden_states = self.activation_fn(hidden_states)
  146. hidden_states, _ = self.fc2(hidden_states)
  147. return hidden_states
  148. class InternVisionEncoderLayer(nn.Module):
  149. def __init__(self,
  150. config: PretrainedConfig,
  151. quant_config: Optional[QuantizationConfig] = None):
  152. super().__init__()
  153. self.embed_dim = config.hidden_size
  154. self.intermediate_size = config.intermediate_size
  155. self.norm_type = config.norm_type
  156. self.attn = InternAttention(config, quant_config=quant_config)
  157. self.mlp = InternMLP(config, quant_config=quant_config)
  158. self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
  159. eps=config.layer_norm_eps)
  160. self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
  161. eps=config.layer_norm_eps)
  162. self.ls1 = nn.Parameter(config.initializer_factor *
  163. torch.ones(self.embed_dim))
  164. self.ls2 = nn.Parameter(config.initializer_factor *
  165. torch.ones(self.embed_dim))
  166. def forward(
  167. self,
  168. hidden_states: torch.Tensor,
  169. ):
  170. hidden_states = hidden_states + self.attn(
  171. self.norm1(hidden_states)) * self.ls1
  172. hidden_states = hidden_states + self.mlp(
  173. self.norm2(hidden_states)) * self.ls2
  174. return hidden_states
  175. class InternVisionEncoder(nn.Module):
  176. def __init__(self,
  177. config: PretrainedConfig,
  178. quant_config: Optional[QuantizationConfig] = None,
  179. num_hidden_layers_override: Optional[int] = None):
  180. super().__init__()
  181. self.config = config
  182. if num_hidden_layers_override is None:
  183. num_hidden_layers = config.num_hidden_layers
  184. else:
  185. num_hidden_layers = num_hidden_layers_override
  186. self.layers = nn.ModuleList([
  187. InternVisionEncoderLayer(config=config, quant_config=quant_config)
  188. for _ in range(num_hidden_layers)
  189. ])
  190. def forward(self, inputs_embeds: torch.Tensor):
  191. hidden_states = inputs_embeds
  192. for encoder_layer in self.layers:
  193. hidden_states = encoder_layer(hidden_states)
  194. return hidden_states
  195. class InternVisionModel(nn.Module):
  196. def __init__(self,
  197. config: PretrainedConfig,
  198. quant_config: Optional[QuantizationConfig] = None,
  199. num_hidden_layers_override: Optional[int] = None):
  200. super().__init__()
  201. self.config = config
  202. self.embeddings = InternVisionEmbeddings(config)
  203. self.encoder = InternVisionEncoder(
  204. config=config,
  205. quant_config=quant_config,
  206. num_hidden_layers_override=num_hidden_layers_override)
  207. def resize_pos_embeddings(self, old_size, new_size, patch_size):
  208. pos_emb = self.embeddings.position_embedding
  209. _, num_positions, embed_dim = pos_emb.shape
  210. cls_emb = pos_emb[:, :1, :]
  211. pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size,
  212. old_size // patch_size,
  213. -1).permute(0, 3, 1, 2)
  214. pos_emb = F.interpolate(pos_emb.float(),
  215. size=new_size // patch_size,
  216. mode='bicubic',
  217. align_corners=False)
  218. pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim,
  219. -1).permute(0, 2, 1)
  220. pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
  221. self.embeddings.position_embedding = nn.Parameter(pos_emb)
  222. self.embeddings.image_size = new_size
  223. def get_input_embeddings(self):
  224. return self.embeddings
  225. def forward(
  226. self,
  227. pixel_values: Optional[torch.Tensor] = None,
  228. pixel_embeds: Optional[torch.Tensor] = None,
  229. ) -> torch.FloatTensor:
  230. if pixel_values is None and pixel_embeds is None:
  231. raise ValueError(
  232. 'You have to specify pixel_values or pixel_embeds')
  233. if pixel_embeds is not None:
  234. hidden_states = pixel_embeds
  235. elif pixel_values is not None:
  236. if pixel_values.ndim == 4:
  237. hidden_states = self.embeddings(pixel_values)
  238. else:
  239. raise ValueError(
  240. f'wrong pixel_values size: {pixel_values.shape}')
  241. encoder_outputs = self.encoder(inputs_embeds=hidden_states)
  242. return encoder_outputs
  243. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  244. params_dict = dict(self.named_parameters())
  245. for name, loaded_weight in weights:
  246. param = params_dict[name]
  247. weight_loader = getattr(param, "weight_loader",
  248. default_weight_loader)
  249. weight_loader(param, loaded_weight)