# Copyright (c) 2022, Tri Dao.
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
import re
from collections import OrderedDict
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth

from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn
except ImportError:
    layer_norm_fn = None


def create_mixer_cls(
    num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False
):
    mixer_cls = partial(
        MHA,
        num_heads=num_heads,
        cross_attn=cross_attn,
        qkv_proj_bias=qkv_bias,
        dropout=attn_drop,
        fused_bias_fc=fused_bias_fc,
        use_flash_attn=use_flash_attn,
    )
    return mixer_cls


def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
    inner_dim = int(embed_dim * mlp_ratio)
    if not fused_mlp:
        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
    else:
        mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
    return mlp_cls


def create_block(
    embed_dim,
    num_heads,
    mlp_ratio,
    qkv_bias,
    drop_rate,
    attn_drop_rate,
    drop_path1,
    drop_path2,
    norm_layer,
    act_layer,
    use_flash_attn,
    fused_bias_fc,
    fused_mlp,
    fused_dropout_add_ln,
    layer_idx=None,
    n_layer=None,
    last_layer_subset=False,
):
    mixer_cls = create_mixer_cls(
        num_heads,
        qkv_bias,
        attn_drop_rate,
        use_flash_attn,
        fused_bias_fc,
        cross_attn=(last_layer_subset and layer_idx == n_layer - 1),
    )
    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
    # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
    block = Block(
        embed_dim,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_layer,
        prenorm=True,
        resid_dropout1=drop_rate,
        resid_dropout2=drop_rate,
        drop_path1=drop_path1,
        drop_path2=drop_path2,
        fused_dropout_add_ln=fused_dropout_add_ln,
        residual_in_fp32=True,
    )
    return block


class VisionTransformer(nn.Module):
    """Vision Transformer

    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        init_values=None,
        class_token=True,
        no_embed_class=False,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        weight_init="",
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
        use_flash_attn=False,
        fused_bias_fc=False,
        fused_mlp=False,
        fused_dropout_add_ln=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): type of global pooling for final sequence (default: 'token')
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            init_values (float): layer-scale init values
            class_token (bool): use class token
            fc_norm (Optional[bool]): apply a norm after pooling, before the classifier head;
                defaults to (global_pool == 'avg') when None (default: None)
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            weight_init (str): weight init scheme
            embed_layer (nn.Module): patch embedding layer
            norm_layer (nn.Module): normalization layer
            act_layer (nn.Module): MLP activation layer
        """
        super().__init__()
        assert global_pool == "token", "Only support pooling with CLS token"
        assert class_token
        assert init_values is None, "LayerScale is not supported yet"
        assert weight_init == ""
        assert fc_norm is None
        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
        assert not pre_norm
        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class

        patch_embed_extra_kwargs = (
            {"fused_bias_fc": fused_bias_fc} if embed_layer is PatchEmbed else {}
        )
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            **patch_embed_extra_kwargs,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add)
        # and the main branch (output of MLP). The model definition is unchanged, but the mapping
        # of the nn.Dropout probabilities is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
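        # Schematic comparison (added for illustration; names are schematic, not executed code):
        #   standard pre-norm block:  x = x + drop_path(dropout(mixer(norm1(x))))
        #   blocks below (prenorm=True) carry an explicit (hidden_states, residual) pair:
        #       residual = drop_path(dropout(hidden_states)) + residual
        #       hidden_states = mixer(norm1(residual))
        #   i.e. the dropout / add / LN steps are regrouped (and can be fused) without changing
        #   the model that is computed.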
        self.blocks = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias,
                    drop_rate,
                    attn_drop_rate,
                    drop_path1=dpr[i - 1] if i > 0 else 0.0,
                    drop_path2=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    use_flash_attn=use_flash_attn,
                    fused_bias_fc=fused_bias_fc,
                    fused_mlp=fused_mlp,
                    fused_dropout_add_ln=fused_dropout_add_ln,
                    layer_idx=i,
                    n_layer=depth,
                    last_layer_subset=(global_pool == "token"),
                )
                for i in range(depth)
            ]
        )

        self.dropout = nn.Dropout(p=drop_rate)
        self.drop_path = StochasticDepth(p=dpr[-1], mode="row")
        self.norm = norm_layer(embed_dim)

        self.fused_dropout_add_ln = fused_dropout_add_ln
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")

        # Classifier Head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=""):
        assert mode == ""
        trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def _pos_embed(self, x):
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.pos_embed
        return x

    def forward_features(self, x, all_tokens=True):
        """
        If all_tokens==False and self.global_pool == 'token', we only return the features for the
        cls token.
        """
        x = self.patch_embed(x)
        hidden_states = self._pos_embed(x)
        residual = None
        if self.global_pool != "token" or all_tokens:
            # if True:
            for block in self.blocks:
                hidden_states, residual = block(hidden_states, residual)
        else:
            for block in self.blocks[:-1]:
                hidden_states, residual = block(hidden_states, residual)
            # For the last layer, we only want the 1st token of the output. So we do
            # cross-attention where the query is the 1st token and the key/value is the whole
            # sequence.
            hidden_states, residual = self.blocks[-1](
                hidden_states, residual, mixer_subset=slice(0, 1)
            )
        if not self.fused_dropout_add_ln:
            residual = self.drop_path(self.dropout(hidden_states)) + residual
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
        else:
            if self.drop_path.p == 0 or not self.training:
                rowscale = None
            else:
                rowscale = self.drop_path(
                    torch.ones(
                        hidden_states.shape[:-1],
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                )
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                eps=self.norm.eps,
                dropout_p=self.dropout.p if self.training else 0.0,
                rowscale=rowscale,
                prenorm=False,
            )
        return hidden_states

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens :].mean(dim=1) if self.global_pool == "avg" else x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x, all_tokens=False)
        x = self.forward_head(x)
        return x

    def load_state_dict(self, state_dict, strict=True):
        patch_embed_weight = state_dict["patch_embed.proj.weight"]
        if patch_embed_weight.dim() == 4:
            # convert from Conv2d to Linear
            state_dict["patch_embed.proj.weight"] = rearrange(
                patch_embed_weight, "o c h w -> o (c h w)"
            )

        def key_mapping_attn(key):
            # e.g. "blocks.0.attn.qkv.weight" -> "blocks.0.mixer.Wqkv.weight"
            key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key)
            key = re.sub(r"^blocks.(\d+).attn.proj.", r"blocks.\1.mixer.out_proj.", key)
            return key

        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

        n_layer = len(self.blocks)
        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
        if (
            self.blocks[-1].mixer.cross_attn
            and f"blocks.{n_layer - 1}.mixer.Wqkv.weight" in state_dict
        ):
            Wqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.weight")
            bqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.bias")
            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.weight"] = Wqkv[: self.embed_dim]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.weight"] = Wqkv[self.embed_dim :]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wq.bias"] = bqkv[: self.embed_dim]
            state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.bias"] = bqkv[self.embed_dim :]
        return super().load_state_dict(state_dict, strict=strict)


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, "init_weights"):
        module.init_weights()


def vit_base_patch16_224(pretrained=False, **kwargs):
    """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    assert not pretrained
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = VisionTransformer(**model_kwargs)
    return model
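

# Minimal usage sketch (added for illustration, not part of the original module). It assumes the
# dependencies imported above (flash_attn, timm, torchvision, einops) are installed. With the
# default flags (use_flash_attn=False, fused_bias_fc=False, fused_mlp=False,
# fused_dropout_add_ln=False) the plain PyTorch code paths are used; enabling the flash-attention
# or fused kernels additionally requires a CUDA GPU and the corresponding extensions.
if __name__ == "__main__":
    model = vit_base_patch16_224(num_classes=1000).eval()
    images = torch.randn(2, 3, 224, 224)  # (batch, channels, height, width)
    with torch.no_grad():
        logits = model(images)  # CLS-token pooling, since global_pool == "token"
    print(logits.shape)  # expected: torch.Size([2, 1000])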