# base_function.py

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm

class LayerNorm2d(nn.Module):
    """Layer normalization over the (C, H, W) dimensions of a 4D tensor."""
    def __init__(self, n_out, affine=True):
        super(LayerNorm2d, self).__init__()
        self.n_out = n_out
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
            self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))

    def forward(self, x):
        normalized_shape = x.size()[1:]
        if self.affine:
            return F.layer_norm(x, normalized_shape,
                                self.weight.expand(normalized_shape),
                                self.bias.expand(normalized_shape))
        else:
            return F.layer_norm(x, normalized_shape)

class ADAINHourglass(nn.Module):
    """Encoder-decoder ("hourglass") generator conditioned on a pose descriptor via AdaIN."""
    def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
        super(ADAINHourglass, self).__init__()
        self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
        self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
        self.output_nc = self.decoder.output_nc

    def forward(self, x, z):
        return self.decoder(self.encoder(x, z), z)

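# Data flow, for reference: the encoder returns a list of feature maps
# (shallowest first); the decoder pops from that list to build skip
# connections, and every block is conditioned on the driving descriptor z
# through AdaIN.
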
class ADAINEncoder(nn.Module):
    def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINEncoder, self).__init__()
        self.layers = layers
        self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
        for i in range(layers):
            in_channels = min(ngf * (2**i), img_f)
            out_channels = min(ngf * (2**(i+1)), img_f)
            model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect)
            setattr(self, 'encoder' + str(i), model)
        self.output_nc = out_channels

    def forward(self, x, z):
        out = self.input_layer(x)
        out_list = [out]
        for i in range(self.layers):
            model = getattr(self, 'encoder' + str(i))
            out = model(out, z)
            out_list.append(out)
        return out_list

class ADAINDecoder(nn.Module):
    """AdaIN-conditioned decoder; optionally concatenates skip features from the encoder."""
    def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINDecoder, self).__init__()
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.skip_connect = skip_connect
        use_transpose = True
        for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
            in_channels = min(ngf * (2**(i+1)), img_f)
            # skip connections double the input channels everywhere except the innermost block
            in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
            out_channels = min(ngf * (2**i), img_f)
            model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
            setattr(self, 'decoder' + str(i), model)
        self.output_nc = out_channels*2 if self.skip_connect else out_channels

    def forward(self, x, z):
        out = x.pop() if self.skip_connect else x
        for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
            model = getattr(self, 'decoder' + str(i))
            out = model(out, z)
            out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
        return out

class ADAINEncoderBlock(nn.Module):
    def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINEncoderBlock, self).__init__()
        kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
        kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect)
        self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)
        self.norm_0 = ADAIN(input_nc, feature_nc)
        self.norm_1 = ADAIN(output_nc, feature_nc)
        self.actvn = nonlinearity

    def forward(self, x, z):
        x = self.conv_0(self.actvn(self.norm_0(x, z)))
        x = self.conv_1(self.actvn(self.norm_1(x, z)))
        return x

class ADAINDecoderBlock(nn.Module):
    def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINDecoderBlock, self).__init__()
        # Attributes
        self.actvn = nonlinearity
        hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc

        kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        if use_transpose:
            kwargs_up = {'kernel_size': 3, 'stride': 2, 'padding': 1, 'output_padding': 1}
        else:
            kwargs_up = {'kernel_size': 3, 'stride': 1, 'padding': 1}

        # create conv layers
        self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
        if use_transpose:
            self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
            self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
        else:
            self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
                                        nn.Upsample(scale_factor=2))
            self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
                                        nn.Upsample(scale_factor=2))

        # define normalization layers
        self.norm_0 = ADAIN(input_nc, feature_nc)
        self.norm_1 = ADAIN(hidden_nc, feature_nc)
        self.norm_s = ADAIN(input_nc, feature_nc)

    def forward(self, x, z):
        x_s = self.shortcut(x, z)
        dx = self.conv_0(self.actvn(self.norm_0(x, z)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
        out = x_s + dx
        return out

    def shortcut(self, x, z):
        x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
        return x_s

def spectral_norm(module, use_spect=True):
    """Apply spectral normalization to a module to stabilize training."""
    if use_spect:
        return SpectralNorm(module)
    else:
        return module

class ADAIN(nn.Module):
    def __init__(self, norm_nc, feature_nc):
        super().__init__()
        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)

        nhidden = 128
        use_bias = True
        self.mlp_shared = nn.Sequential(
            nn.Linear(feature_nc, nhidden, bias=use_bias),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)
        self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)

    def forward(self, x, feature):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on feature
        feature = feature.view(feature.size(0), -1)
        actv = self.mlp_shared(feature)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        gamma = gamma.view(*gamma.size()[:2], 1, 1)
        beta = beta.view(*beta.size()[:2], 1, 1)
        out = normalized * (1 + gamma) + beta
        return out

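# AdaIN modulation, for reference: with n = InstanceNorm(x) and a conditioning
# vector f (flattened to shape (B, feature_nc)),
#     out = n * (1 + gamma(f)) + beta(f)
# where gamma and beta are per-channel parameters predicted by the shared MLP.
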
class FineEncoder(nn.Module):
    """Image encoder: a full-resolution first block followed by `layers` downsampling blocks."""
    def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineEncoder, self).__init__()
        self.layers = layers
        self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
        for i in range(layers):
            in_channels = min(ngf*(2**i), img_f)
            out_channels = min(ngf*(2**(i+1)), img_f)
            model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            setattr(self, 'down' + str(i), model)
        self.output_nc = out_channels

    def forward(self, x):
        x = self.first(x)
        out = [x]
        for i in range(self.layers):
            model = getattr(self, 'down' + str(i))
            x = model(x)
            out.append(x)
        return out

class FineDecoder(nn.Module):
    """Image decoder: AdaIN residual blocks plus upsampling, with skip ("jump") branches from the encoder."""
    def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineDecoder, self).__init__()
        self.layers = layers
        for i in range(layers)[::-1]:
            in_channels = min(ngf*(2**(i+1)), img_f)
            out_channels = min(ngf*(2**i), img_f)
            up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
            jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
            setattr(self, 'up' + str(i), up)
            setattr(self, 'res' + str(i), res)
            setattr(self, 'jump' + str(i), jump)
        self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')
        self.output_nc = out_channels

    def forward(self, x, z):
        out = x.pop()
        for i in range(self.layers)[::-1]:
            res_model = getattr(self, 'res' + str(i))
            up_model = getattr(self, 'up' + str(i))
            jump_model = getattr(self, 'jump' + str(i))
            out = res_model(out, z)
            out = up_model(out)
            out = jump_model(x.pop()) + out
        out_image = self.final(out)
        return out_image

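# For reference: each FineDecoder level runs AdaIN residual blocks, upsamples,
# and adds a "jump" branch computed from the matching encoder feature; a final
# 7x7 convolution with tanh (or sigmoid) produces the output image.
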
class FirstBlock2d(nn.Module):
    """
    Full-resolution input block of the encoder (7x7 convolution, no downsampling).
    """
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FirstBlock2d, self).__init__()
        kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        if norm_layer is None:
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out

class DownBlock2d(nn.Module):
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(DownBlock2d, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        pool = nn.AvgPool2d(kernel_size=(2, 2))
        if norm_layer is None:
            self.model = nn.Sequential(conv, nonlinearity, pool)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)

    def forward(self, x):
        out = self.model(x)
        return out

class UpBlock2d(nn.Module):
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(UpBlock2d, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        if norm_layer is None:
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)

    def forward(self, x):
        out = self.model(F.interpolate(x, scale_factor=2))
        return out

class FineADAINResBlocks(nn.Module):
    def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINResBlocks, self).__init__()
        self.num_block = num_block
        for i in range(num_block):
            model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
            setattr(self, 'res' + str(i), model)

    def forward(self, x, z):
        for i in range(self.num_block):
            model = getattr(self, 'res' + str(i))
            x = model(x, z)
        return x

class Jump(nn.Module):
    def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(Jump, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        if norm_layer is None:
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out

class FineADAINResBlock2d(nn.Module):
    """
    Residual block with AdaIN-conditioned normalization.
    """
    def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINResBlock2d, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        self.norm1 = ADAIN(input_nc, feature_nc)
        self.norm2 = ADAIN(input_nc, feature_nc)
        self.actvn = nonlinearity

    def forward(self, x, z):
        dx = self.actvn(self.norm1(self.conv1(x), z))
        dx = self.norm2(self.conv2(dx), z)  # feed dx (not x) into conv2 so the first branch is not discarded
        out = dx + x
        return out

class FinalBlock2d(nn.Module):
    """
    Output layer: 7x7 convolution followed by Tanh or Sigmoid.
    """
    def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
        super(FinalBlock2d, self).__init__()
        kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        if tanh_or_sigmoid == 'sigmoid':
            out_nonlinearity = nn.Sigmoid()
        else:
            out_nonlinearity = nn.Tanh()
        self.model = nn.Sequential(conv, out_nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out
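
# Minimal smoke test; a sketch only, with assumed hyper-parameters (image_nc=3,
# pose_nc=64, ngf=32, img_f=256, layer counts); the real training configuration
# is not part of this file.
if __name__ == "__main__":
    net = ADAINHourglass(image_nc=3, pose_nc=64, ngf=32, img_f=256,
                         encoder_layers=3, decoder_layers=3,
                         nonlinearity=nn.LeakyReLU(0.2), use_spect=False)
    x = torch.randn(2, 3, 64, 64)   # input image batch
    z = torch.randn(2, 64)          # pose descriptor, flattened inside ADAIN
    print(net(x, z).shape)          # torch.Size([2, 64, 64, 64])

    enc = FineEncoder(image_nc=3, ngf=32, img_f=256, layers=2)
    dec = FineDecoder(image_nc=3, feature_nc=64, ngf=32, img_f=256,
                      layers=2, num_block=2)
    feats = enc(torch.randn(2, 3, 64, 64))
    print(dec(feats, torch.randn(2, 64)).shape)  # torch.Size([2, 3, 64, 64])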