123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- import functools
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import generators.flow_util as flow_util
- from generators.base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
class FaceGenerator(nn.Module):
    """Top-level generator: mapping net -> warping net -> (optionally) editing net.

    The sub-networks are built from config dicts:
    ``MappingNet(**mapping_net)``, ``WarpingNet(**warpping_net, **common)`` and
    ``EditingNet(**editing_net, **common)``; ``common`` carries keyword
    arguments shared by the warping and editing networks.
    """

    def __init__(
        self,
        mapping_net,
        warpping_net,
        editing_net,
        common
    ):
        super(FaceGenerator, self).__init__()
        # NOTE: "warpping" (sic) is deliberately kept. It is both the public
        # constructor kwarg and the attribute / state_dict prefix, so renaming
        # it would break existing configs and saved checkpoints.
        self.mapping_net = MappingNet(**mapping_net)
        self.warpping_net = WarpingNet(**warpping_net, **common)
        self.editing_net = EditingNet(**editing_net, **common)

    def forward(
        self,
        input_image,
        driving_source,
        stage=None
    ):
        """Run the generator on one source image / driving signal pair.

        Args:
            input_image: image passed to the warping net and the editing net.
            driving_source: input of the mapping net; its output (the
                descriptor) conditions both later stages.
            stage: ``'warp'`` stops after the warping net; any other value
                (including the default ``None``) also runs the editing net.

        Returns:
            The dict returned by the warping net (read here for its
            ``'warp_image'`` entry). Unless ``stage == 'warp'``, it is extended
            in place with ``'fake_image'``, the editing net's output.
        """
        # Mapping and warping are common to every stage; previously both
        # branches repeated them verbatim.
        descriptor = self.mapping_net(driving_source)
        output = self.warpping_net(input_image, descriptor)
        if stage != 'warp':
            output['fake_image'] = self.editing_net(
                input_image, output['warp_image'], descriptor)
        return output
class MappingNet(nn.Module):
    """Encode a 1-D sequence of coefficient vectors into a single descriptor.

    Architecture: an unpadded kernel-7 ``Conv1d`` (``coeff_nc`` -> ``descriptor_nc``),
    then ``layer`` residual blocks of ``LeakyReLU(0.1) -> Conv1d(kernel 3,
    dilation 3, no padding)``, then global average pooling over the last axis.

    Args:
        coeff_nc: number of input channels.
        descriptor_nc: channel width of every hidden feature and of the output.
        layer: number of residual encoder blocks.
    """

    def __init__(self, coeff_nc, descriptor_nc, layer):
        super().__init__()
        self.layer = layer
        # A single activation instance is shared by all encoder blocks.
        act = nn.LeakyReLU(0.1)
        self.first = nn.Sequential(
            torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)
        )
        # Blocks are registered as attributes encoder0, encoder1, ... rather than
        # in a ModuleList, so the state_dict keys match existing checkpoints.
        for idx in range(layer):
            block = nn.Sequential(
                act,
                torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3),
            )
            setattr(self, f'encoder{idx}', block)
        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.output_nc = descriptor_nc

    def _encoder_blocks(self):
        """Yield the residual blocks in registration order (encoder0, encoder1, ...)."""
        for idx in range(self.layer):
            yield getattr(self, f'encoder{idx}')

    def forward(self, input_3dmm):
        """Return the pooled descriptor, shape ``(N, descriptor_nc, 1)``.

        ``input_3dmm`` is indexed as ``(N, coeff_nc, L)``. Every conv here is
        unpadded with an effective receptive field of 7, so each one shortens
        the last axis by 6. ``L`` must therefore be at least
        ``6 * (layer + 1) + 1``.
        """
        feat = self.first(input_3dmm)
        for block in self._encoder_blocks():
            # The conv drops 3 steps at each end, so the skip path is
            # centre-cropped by the same amount before the residual add.
            feat = block(feat) + feat[:, :, 3:-3]
        return self.pooling(feat)
class WarpingNet(nn.Module):
    """Predict a 2-channel flow field from an image and a descriptor, then warp the image.

    Args:
        image_nc, descriptor_nc, base_nc, max_nc, encoder_layer, decoder_layer:
            passed positionally to ``ADAINHourglass``.
        use_spect: passed to ``ADAINHourglass`` as a keyword argument.
    """

    def __init__(
        self,
        image_nc,
        descriptor_nc,
        base_nc,
        max_nc,
        encoder_layer,
        decoder_layer,
        use_spect
    ):
        super().__init__()
        # The same activation instance is handed to the hourglass and reused in flow_out.
        act = nn.LeakyReLU(0.1)
        make_norm = functools.partial(LayerNorm2d, affine=True)
        self.descriptor_nc = descriptor_nc
        self.hourglass = ADAINHourglass(
            image_nc, descriptor_nc, base_nc, max_nc, encoder_layer, decoder_layer,
            nonlinearity=act, use_spect=use_spect,
        )
        hidden_nc = self.hourglass.output_nc
        # Head: norm -> activation -> 7x7 conv (padding 3) down to 2 output channels.
        self.flow_out = nn.Sequential(
            make_norm(hidden_nc),
            act,
            nn.Conv2d(hidden_nc, 2, kernel_size=7, stride=1, padding=3),
        )
        # NOTE(review): `pool` is never used in forward(); kept so the module tree is unchanged.
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, input_image, descriptor):
        """Return a dict with keys ``'flow_field'`` and ``'warp_image'``.

        ``'flow_field'`` is the 2-channel head output. ``'warp_image'`` is
        ``input_image`` warped by the deformation that
        ``flow_util.convert_flow_to_deformation`` derives from that flow.
        """
        flow = self.flow_out(self.hourglass(input_image, descriptor))
        deformation = flow_util.convert_flow_to_deformation(flow)
        return {
            'flow_field': flow,
            'warp_image': flow_util.warp_image(input_image, deformation),
        }
class EditingNet(nn.Module):
    """Refine a warped image using the original image and a descriptor.

    The input image and the warped image are concatenated along dim 1 and
    encoded by ``FineEncoder``. ``FineDecoder`` then decodes the features,
    conditioned on the descriptor.

    Args:
        image_nc: channels of each image; the encoder receives ``2 * image_nc``
            because two images are stacked.
        descriptor_nc, layer, base_nc, max_nc, num_res_blocks, use_spect:
            forwarded to the encoder/decoder as in the calls below.
    """

    def __init__(
        self,
        image_nc,
        descriptor_nc,
        layer,
        base_nc,
        max_nc,
        num_res_blocks,
        use_spect):
        super().__init__()
        act = nn.LeakyReLU(0.1)
        # One set of options (same norm factory and same activation instance)
        # is shared by encoder and decoder.
        shared = {
            'norm_layer': functools.partial(LayerNorm2d, affine=True),
            'nonlinearity': act,
            'use_spect': use_spect,
        }
        self.descriptor_nc = descriptor_nc
        self.encoder = FineEncoder(2 * image_nc, base_nc, max_nc, layer, **shared)
        self.decoder = FineDecoder(
            image_nc, descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **shared)

    def forward(self, input_image, warp_image, descriptor):
        """Return the decoder output for ``[input_image, warp_image]`` stacked on dim 1."""
        stacked = torch.cat((input_image, warp_image), dim=1)
        features = self.encoder(stacked)
        return self.decoder(features, descriptor)
|