resampler.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
#
# Copyright 2023 The Qwen team.
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Shared resampler perceiver network used in multimodal models and
related helpers for sincos positional embeddings.

Example models: Qwen (Qwen-VL), Minicpmv2.0
"""
import math
from functools import partial
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import trunc_normal_

from aphrodite.modeling.layers.linear import ReplicatedLinear

DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


def get_abs_pos(
    abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, int]
) -> torch.Tensor:
    # abs_pos: (L, C), a flattened square grid of positional embeddings
    # tgt_size: (H, W), or a single int for a square target grid
    # return: (M, C) with M = H * W
    src_size = int(math.sqrt(abs_pos.size(0)))
    dtype = abs_pos.dtype
    if isinstance(tgt_size, int):
        tgt_size = (tgt_size, tgt_size)
    if src_size == tgt_size[0] and src_size == tgt_size[1]:
        return abs_pos

    return (
        F.interpolate(
            abs_pos.float()
            .reshape(1, src_size, src_size, -1)
            .permute(0, 3, 1, 2),
            size=(tgt_size[0], tgt_size[1]),
            mode="bicubic",
            align_corners=False,
        )
        .permute(0, 2, 3, 1)
        .flatten(0, 2)
        .to(dtype=dtype)
    )
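

# Usage sketch (illustrative only, not part of the original module): resizing
# a cached 24x24 grid of positional embeddings to a 32x32 feature grid. The
# sizes here are made up for the example.
#
#     cached = torch.randn(24 * 24, 1024)       # (L, C), L must be a square
#     resized = get_abs_pos(cached, (32, 32))   # -> (32 * 32, 1024), bicubic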


# sin/cos positional embedding helpers are adapted from:
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: np.ndarray, version: Tuple[int, int] = (2, 0)
) -> np.ndarray:
    """
    embed_dim: output dimension for each position
    pos: positions to be encoded, of shape (M,) for version (2, 0)
         or (H, W) otherwise
    out: (M, D) / (H, W, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    if version == (2, 0):
        pos = pos.reshape(-1)  # (M,)
        out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
        emb_sin = np.sin(out)  # (M, D/2)
        emb_cos = np.cos(out)  # (M, D/2)
        emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    else:
        out = np.einsum("hw,d->hwd", pos, omega)  # (H, W, D/2), outer product
        emb_sin = np.sin(out)  # (H, W, D/2)
        emb_cos = np.cos(out)  # (H, W, D/2)
        emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (H, W, D)
    return emb
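

# For reference, the encoding computed above (a restatement, not additional
# behavior): with D = embed_dim and frequencies
#     omega_i = 1 / 10000 ** (2 * i / D),   i = 0 .. D/2 - 1,
# each scalar position p is mapped to
#     [sin(p * omega_0), ..., sin(p * omega_{D/2 - 1}),
#      cos(p * omega_0), ..., cos(p * omega_{D/2 - 1})].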


def get_2d_sincos_pos_embed_from_grid(
    embed_dim: int, grid: np.ndarray, version: Tuple[int, int] = (2, 0)
) -> np.ndarray:
    assert embed_dim % 2 == 0

    # use half of the dimensions to encode grid_h, the other half for grid_w
    emb_h = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[0], version
    )  # (H*W, D/2) or (H, W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[1], version
    )  # (H*W, D/2) or (H, W, D/2)

    if version == (2, 0):
        emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    else:
        emb = np.concatenate([emb_h, emb_w], axis=-1)  # (H, W, D)
    return emb


def get_2d_sincos_pos_embed(
    embed_dim: int,
    grid_size: Union[int, Tuple[int, int]],
    cls_token: bool = False,
    version: Tuple[int, int] = (2, 0),
) -> np.ndarray:
    """
    grid_size: int for a square grid, or an (H, W) tuple
    return:
        pos_embed: [grid_size*grid_size, embed_dim] or
                   [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_h_size, grid_w_size = grid_size, grid_size
    else:
        grid_h_size, grid_w_size = grid_size[0], grid_size[1]

    grid_h = np.arange(grid_h_size, dtype=np.float32)
    grid_w = np.arange(grid_w_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    assert isinstance(grid, np.ndarray) and grid.shape == (
        2,
        grid_h_size,
        grid_w_size,
    )

    if version == (2, 0):
        grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
        pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
        if cls_token:
            pos_embed = np.concatenate(
                [np.zeros([1, embed_dim]), pos_embed], axis=0
            )
    else:
        pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
    return pos_embed
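

# Usage sketch (illustrative only; the dimensions below are made up):
#
#     pos_sq = get_2d_sincos_pos_embed(128, 16)         # (256, 128) ndarray
#     pos_hw = get_2d_sincos_pos_embed(128, (12, 20))    # (240, 128)
#     pos_cls = get_2d_sincos_pos_embed(128, 16, cls_token=True)  # (257, 128)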


class BaseResampler(nn.Module):
    """
    A 2D perceiver-resampler network with a single cross-attention layer that
    attends from (num_queries) learnable queries, augmented with 2D sincos
    positional embeddings, to the input features.

    Outputs:
        A tensor of shape (num_queries, embed_dim)
    """

    def __init__(
        self,
        num_queries: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        do_post_projection: bool = True,
    ) -> None:
        super().__init__()

        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=0.02)
        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
        else:
            # Maintain the same return value as ReplicatedLinear.forward
            self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
                nn.Identity()(*args, **kwargs),
                None,
            )
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)
        self.do_post_projection = do_post_projection
        self.ln_post = norm_layer(embed_dim) if do_post_projection else None
        self.proj = (
            nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
            if do_post_projection
            else None
        )

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _repeat(self, query, N: int):
        return query.unsqueeze(1).repeat(1, N, 1)


class Resampler2(BaseResampler):
    """Resampler-perceiver network to be used for a variety of model types,
    e.g., Qwen-VL / MiniCPM-V 2.0. The main difference is the addition of the
    do_post_projection arg, which indicates whether or not there should be
    a post layer normalization and projector after the attention. This is
    present in MiniCPM-V 2.0, but not Qwen-VL.
    """

    def __init__(
        self,
        grid_size: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        adaptive: bool = False,
        do_post_projection: bool = True,
    ) -> None:
        super().__init__(
            grid_size**2,
            embed_dim,
            num_heads,
            kv_dim,
            norm_layer,
            do_post_projection=do_post_projection,
        )

        self.adaptive = adaptive
        pos_embed_arr = get_2d_sincos_pos_embed(
            embed_dim, grid_size, version=(2, 0)
        )
        self.pos_embed = nn.Parameter(
            torch.from_numpy(pos_embed_arr).requires_grad_(False)
        )

        self.apply(self._init_weights)

    def forward(
        self,
        x: torch.Tensor,
        tgt_sizes: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if tgt_sizes is None:
            tgt_sizes = int(math.sqrt(x.size(1)))
        if self.adaptive:
            pos_embed_arr = get_2d_sincos_pos_embed(
                self.embed_dim, tgt_sizes, version=(2, 0)
            )
            pos_embed = torch.from_numpy(pos_embed_arr).to(
                device=x.device, dtype=x.dtype
            )
        else:
            pos_embed = get_abs_pos(self.pos_embed, tgt_sizes).to(
                device=x.device, dtype=x.dtype
            )

        x, _ = self.kv_proj(x)
        x = self.ln_kv(x).permute(1, 0, 2)

        N = x.shape[1]
        q = self.ln_q(self.query)

        out = self.attn(
            self._repeat(q, N) + self.pos_embed.unsqueeze(1),
            x + pos_embed.unsqueeze(1),
            x,
            attn_mask=attn_mask,
        )[0]
        x = out.permute(1, 0, 2)
        if self.do_post_projection:
            x = self.ln_post(x)
            x = x @ self.proj
        return x
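

# Usage sketch (illustrative only; the sizes below are made up and do not
# come from any particular model config):
#
#     resampler = Resampler2(
#         grid_size=16,          # 16**2 = 256 learnable queries
#         embed_dim=1024,
#         num_heads=8,
#         kv_dim=1664,           # projected to embed_dim by kv_proj
#         adaptive=False,
#         do_post_projection=True,
#     )
#     feats = torch.randn(2, 24 * 24, 1664)   # (batch, seq, kv_dim) features
#     out = resampler(feats)                  # -> (2, 256, 1024)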