lpips.py

  1. """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
  2. import torch
  3. import torch.nn as nn
  4. from torchvision import models
  5. from collections import namedtuple
  6. from taming.util import get_ckpt_path
  7. class LPIPS(nn.Module):
  8. # Learned perceptual metric
  9. def __init__(self, use_dropout=True):
  10. super().__init__()
  11. self.scaling_layer = ScalingLayer()
  12. self.chns = [64, 128, 256, 512, 512] # vg16 features
  13. self.net = vgg16(pretrained=True, requires_grad=False)
  14. self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
  15. self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
  16. self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
  17. self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
  18. self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
  19. self.load_from_pretrained()
  20. for param in self.parameters():
  21. param.requires_grad = False
  22. def load_from_pretrained(self, name="vgg_lpips"):
  23. ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
  24. self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
  25. print("loaded pretrained LPIPS loss from {}".format(ckpt))
  26. @classmethod
  27. def from_pretrained(cls, name="vgg_lpips"):
  28. if name != "vgg_lpips":
  29. raise NotImplementedError
  30. model = cls()
  31. ckpt = get_ckpt_path(name)
  32. model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
  33. return model
  34. def forward(self, input, target):
  35. in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
  36. outs0, outs1 = self.net(in0_input), self.net(in1_input)
  37. feats0, feats1, diffs = {}, {}, {}
  38. lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
  39. for kk in range(len(self.chns)):
  40. feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
  41. diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
  42. res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
  43. val = res[0]
  44. for l in range(1, len(self.chns)):
  45. val += res[l]
  46. return val
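# Editorial note: the forward pass above implements the LPIPS distance of
# Zhang et al. (2018),
#
#     d(x0, x1) = sum_l mean_{h,w} || w_l * (f^_l(x0) - f^_l(x1)) ||_2^2
#
# where f^_l are the channel-wise unit-normalized VGG16 activations at layer l
# (see normalize_tensor below) and w_l are the learned 1x1-conv weights of the
# NetLinLayer heads lin0..lin4.
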
class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        # shift/scale are the ImageNet mean/std remapped to [-1, 1] inputs:
        # shift = 2 * mean - 1, scale = 2 * std
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale

class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout()] if use_dropout else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False)]
        self.model = nn.Sequential(*layers)

class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        # slice boundaries cut vgg16.features right after relu1_2, relu2_2,
        # relu3_3, relu4_3 and relu5_3
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out

def normalize_tensor(x, eps=1e-10):
    # unit-normalize each feature vector along the channel dimension
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    # mean over the spatial (H, W) dimensions
    return x.mean([2, 3], keepdim=keepdim)
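

# A minimal usage sketch (illustrative; not part of the original module).
# Assumes the taming-transformers package is installed so that
# taming.util.get_ckpt_path can fetch the pretrained "vgg_lpips" checkpoint.
# Inputs are NCHW RGB tensors scaled to [-1, 1], the range the ScalingLayer
# statistics expect.
if __name__ == "__main__":
    loss_fn = LPIPS().eval()
    x = torch.rand(1, 3, 256, 256) * 2.0 - 1.0  # random "image" in [-1, 1]
    y = torch.rand(1, 3, 256, 256) * 2.0 - 1.0
    with torch.no_grad():
        dist = loss_fn(x, y)  # shape (1, 1, 1, 1)
    print(dist.item())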