patch_embed.py

# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py
# But we use nn.Linear instead of Conv2d and it's about 8x faster.
from functools import partial

import torch.nn as nn
from einops import rearrange
from torch import _assert
from torch.nn.modules.utils import _pair

try:
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    FusedDense = None


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
        fused_bias_fc=False,
    ):
        super().__init__()
        img_size = _pair(img_size)
        patch_size = _pair(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")

        # FusedDense fuses the bias addition into the GEMM, so it is only used
        # when a bias is actually requested.
        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
        self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        _assert(
            H == self.img_size[0],
            f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
        )
        _assert(
            W == self.img_size[1],
            f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
        )
        # Cut the image into non-overlapping patch_size x patch_size patches, flatten
        # each patch to a vector of length in_chans * p1 * p2, and project it with a
        # single Linear layer (equivalent to a strided Conv2d patch embedding).
        x = self.proj(
            rearrange(
                x,
                "b c (h p1) (w p2) -> b h w (c p1 p2)",
                p1=self.patch_size[0],
                p2=self.patch_size[1],
            )
        )
        if self.flatten:
            # (B, H/p1, W/p2, D) -> (B, num_patches, D)
            x = rearrange(x, "b h w c -> b (h w) c")
        x = self.norm(x)
        return x
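

if __name__ == "__main__":
    # Minimal usage / sanity-check sketch (illustration only, not part of the original
    # module): it shows the expected output shape and checks that the Linear-based
    # projection matches a strided Conv2d patch embedding when the weights are shared.
    # The names below (x, conv, ref) are local to this example.
    import torch

    torch.manual_seed(0)
    patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    x = torch.randn(2, 3, 224, 224)
    out = patch_embed(x)
    print(out.shape)  # torch.Size([2, 196, 768]); 196 = (224 // 16) ** 2 patches

    # Conv2d reference: kernel_size == stride == patch_size, with its weight taken from
    # the Linear weight reshaped to (embed_dim, in_chans, p1, p2). The "(c p1 p2)"
    # ordering in forward() matches Conv2d's (in_chans, kH, kW) weight layout, so the
    # two should agree up to floating-point error.
    conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
    with torch.no_grad():
        conv.weight.copy_(patch_embed.proj.weight.reshape(768, 3, 16, 16))
        conv.bias.copy_(patch_embed.proj.bias)
        ref = rearrange(conv(x), "b c h w -> b (h w) c")
    print(torch.allclose(out, ref, atol=1e-5))  # expected: True (within fp32 tolerance)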