# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py
# But we use nn.Linear instead of Conv2d and it's about 8x faster.
# (Extracting non-overlapping patches and applying a Linear layer is mathematically
# equivalent to a Conv2d with kernel_size = stride = patch_size.)
from functools import partial

import torch.nn as nn
from torch import _assert
from torch.nn.modules.utils import _pair

from einops import rearrange

try:
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    FusedDense = None
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
        fused_bias_fc=False,
    ):
        super().__init__()
        img_size = _pair(img_size)
        patch_size = _pair(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # Number of patches along each spatial dimension, and in total.
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')

        # Use the fused bias + GEMM kernel only when it is requested and a bias is present;
        # otherwise fall back to a plain nn.Linear.
        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
        self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
    def forward(self, x):
        _, _, H, W = x.shape
        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
        # Split the image into non-overlapping patches and project each flattened patch
        # to embed_dim with a single matrix multiply.
        x = self.proj(rearrange(x, 'b c (h p1) (w p2) -> b h w (c p1 p2)',
                                p1=self.patch_size[0], p2=self.patch_size[1]))
        if self.flatten:
            # (B, H', W', C) -> (B, H'*W', C): one token per patch.
            x = rearrange(x, 'b h w c -> b (h w) c')
        x = self.norm(x)
        return x
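

# Minimal usage sketch, assuming the default 224x224 RGB input with 16x16 patches;
# the shapes below are illustrative.
if __name__ == "__main__":
    import torch

    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    imgs = torch.randn(2, 3, 224, 224)  # (batch, channels, height, width)
    tokens = embed(imgs)
    print(tokens.shape)  # torch.Size([2, 196, 768]); 196 = (224 // 16) ** 2 patches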