face_model.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import functools
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. import generators.flow_util as flow_util
  7. from generators.base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
  class FaceGenerator(nn.Module):
      """Generator composed of a mapping, a warping and an editing network.

      ``MappingNet`` turns the driving source into a descriptor.
      ``WarpingNet`` uses that descriptor to warp the input image.
      ``EditingNet`` optionally refines the warped result.
      """

      def __init__(
          self,
          mapping_net,
          warpping_net,
          editing_net,
          common
      ):
          """Build the three sub-networks.

          Args:
              mapping_net: keyword arguments for ``MappingNet``.
              warpping_net: keyword arguments for ``WarpingNet``, merged with
                  ``common``. The misspelling is kept on purpose: it is part of
                  the constructor interface and of the state-dict key prefix.
              editing_net: keyword arguments for ``EditingNet``, merged with
                  ``common``.
              common: keyword arguments shared by ``WarpingNet`` and
                  ``EditingNet``.
          """
          super(FaceGenerator, self).__init__()
          self.mapping_net = MappingNet(**mapping_net)
          self.warpping_net = WarpingNet(**warpping_net, **common)
          self.editing_net = EditingNet(**editing_net, **common)

      def forward(
          self,
          input_image,
          driving_source,
          stage=None
      ):
          """Run the generator.

          Args:
              input_image: source image fed to the warping and editing nets.
              driving_source: input to ``MappingNet``. Presumably a window of
                  3DMM coefficients; confirm with the data pipeline.
              stage: ``'warp'`` runs only the mapping and warping networks. Any
                  other value (including ``None``) also runs the editing network.

          Returns:
              The dict returned by the warping network. Unless
              ``stage == 'warp'``, it also contains ``'fake_image'``, the
              editing network's output.
          """
          # Mapping and warping are needed in every stage; previously this code
          # was duplicated in both branches. Only the editing step is optional.
          descriptor = self.mapping_net(driving_source)
          output = self.warpping_net(input_image, descriptor)
          if stage != 'warp':
              output['fake_image'] = self.editing_net(
                  input_image, output['warp_image'], descriptor)
          return output
  class MappingNet(nn.Module):
      """Encode a 1-D sequence of driving coefficients into one descriptor.

      The network applies one wide Conv1d (kernel 7, no padding). It then
      applies ``layer`` residual blocks of LeakyReLU -> dilated Conv1d, and
      finally average-pools the sequence axis down to length 1.
      """

      def __init__(self, coeff_nc, descriptor_nc, layer):
          """Create the input conv, ``layer`` residual blocks and the pooling.

          Args:
              coeff_nc: channels of the input sequence.
              descriptor_nc: channels of every hidden feature and of the output.
              layer: number of residual blocks.
          """
          super().__init__()
          self.layer = layer
          # One stateless activation instance is reused inside every block.
          act = nn.LeakyReLU(0.1)
          self.first = nn.Sequential(
              torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)
          )
          # Blocks are registered as encoder0, encoder1, ... in order, so the
          # submodule names (and therefore checkpoint keys) are fixed.
          for idx in range(layer):
              block = nn.Sequential(
                  act,
                  torch.nn.Conv1d(descriptor_nc, descriptor_nc,
                                  kernel_size=3, padding=0, dilation=3),
              )
              setattr(self, f'encoder{idx}', block)
          self.pooling = nn.AdaptiveAvgPool1d(1)
          self.output_nc = descriptor_nc

      def forward(self, input_3dmm):
          """Map ``input_3dmm`` of shape (batch, coeff_nc, length) to (batch, descriptor_nc, 1).

          The valid convolutions shrink the sequence by 6 in the input conv and
          by 6 per block. ``length`` must therefore be at least
          ``7 + 6 * self.layer``.
          """
          feat = self.first(input_3dmm)
          for idx in range(self.layer):
              block = getattr(self, f'encoder{idx}')
              # kernel 3 / dilation 3 / no padding drops 3 steps from each end,
              # so the skip path is cropped by the same amount.
              feat = block(feat) + feat[:, :, 3:-3]
          return self.pooling(feat)
  class WarpingNet(nn.Module):
      """Predict a 2-channel flow field from an image and a descriptor, then warp the image with it."""

      def __init__(
          self,
          image_nc,
          descriptor_nc,
          base_nc,
          max_nc,
          encoder_layer,
          decoder_layer,
          use_spect
      ):
          """Build the ADAIN hourglass backbone and the flow prediction head.

          The arguments are forwarded to ``ADAINHourglass``, together with a
          LeakyReLU(0.1) nonlinearity and ``use_spect``.
          """
          super().__init__()
          act = nn.LeakyReLU(0.1)
          make_norm = functools.partial(LayerNorm2d, affine=True)
          self.descriptor_nc = descriptor_nc
          self.hourglass = ADAINHourglass(
              image_nc, self.descriptor_nc, base_nc, max_nc,
              encoder_layer, decoder_layer,
              nonlinearity=act, use_spect=use_spect,
          )
          hourglass_nc = self.hourglass.output_nc
          # Head: norm -> activation -> 7x7 conv producing 2 output channels.
          self.flow_out = nn.Sequential(
              make_norm(hourglass_nc),
              act,
              nn.Conv2d(hourglass_nc, 2, kernel_size=7, stride=1, padding=3),
          )
          # NOTE(review): forward() never uses this module; it is kept so the
          # module tree stays unchanged.
          self.pool = nn.AdaptiveAvgPool2d(1)

      def forward(self, input_image, descriptor):
          """Return ``{'flow_field': flow, 'warp_image': warped}`` for ``input_image``."""
          features = self.hourglass(input_image, descriptor)
          flow = self.flow_out(features)
          deformation = flow_util.convert_flow_to_deformation(flow)
          return {
              'flow_field': flow,
              'warp_image': flow_util.warp_image(input_image, deformation),
          }
  class EditingNet(nn.Module):
      """Produce the final image from the source image, the warped image and a descriptor."""

      def __init__(
          self,
          image_nc,
          descriptor_nc,
          layer,
          base_nc,
          max_nc,
          num_res_blocks,
          use_spect):
          """Build the fine encoder/decoder pair.

          Both receive the same norm layer (affine LayerNorm2d), nonlinearity
          (LeakyReLU(0.1)) and ``use_spect`` setting.
          """
          super().__init__()
          shared = dict(
              norm_layer=functools.partial(LayerNorm2d, affine=True),
              nonlinearity=nn.LeakyReLU(0.1),
              use_spect=use_spect,
          )
          self.descriptor_nc = descriptor_nc
          # The encoder input is source + warped image stacked along dim 1,
          # hence twice the image channels.
          self.encoder = FineEncoder(image_nc * 2, base_nc, max_nc, layer, **shared)
          self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc,
                                     layer, num_res_blocks, **shared)

      def forward(self, input_image, warp_image, descriptor):
          """Encode the stacked images, then decode conditioned on ``descriptor``."""
          # Dim 1 is the channel axis, assuming NCHW tensors; confirm with callers.
          stacked = torch.cat((input_image, warp_image), dim=1)
          return self.decoder(self.encoder(stacked), descriptor)