123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- import torch
- import torch.nn as nn
def count_params(model):
    """Return the total number of scalar elements in ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
class ActNorm(nn.Module):
    """Per-channel affine normalization with data-dependent initialization.

    The forward map is ``h = scale * (x + loc)`` with ``loc`` and ``scale``
    shaped ``(1, num_features, 1, 1)``. On the first call made in training
    mode while the ``initialized`` buffer is 0, ``loc``/``scale`` are set from
    the batch statistics so the output has per-channel zero mean and roughly
    unit standard deviation.

    Inputs may be 2-D ``(batch, channels)`` or 4-D
    ``(batch, channels, height, width)``; outputs keep the input's rank.

    Args:
        num_features: number of channels (size of dim 1 of the input).
        logdet: if True, ``forward`` also returns the log-determinant of the
            transform, one value per batch element.
        affine: must be True; only the affine variant is implemented.
        allow_reverse_init: if True, the data-dependent initialization may be
            triggered by a ``reverse`` call; otherwise that raises.
    """

    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        # Shaped to broadcast over (batch, channels, height, width) tensors.
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init
        # A buffer (not a plain attribute) so the init flag is part of state_dict.
        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    @staticmethod
    def _to_4d(x):
        """Return ``(x4d, squeezed)``; a 2-D tensor gets two trailing unit dims."""
        if x.dim() == 2:
            return x[:, :, None, None], True
        return x, False

    def initialize(self, input):
        """Set ``loc``/``scale`` from the per-channel statistics of ``input``.

        ``input`` may be 2-D or 4-D with channels on dim 1.
        """
        input, _ = self._to_4d(input)
        with torch.no_grad():
            # (C, N*H*W): all samples of each channel on one row.
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            # Per-channel stats reshaped to (1, C, 1, 1) to match loc/scale.
            mean = flatten.mean(1).view(1, -1, 1, 1)
            std = flatten.std(1).view(1, -1, 1, 1)
            self.loc.data.copy_(-mean)
            # Epsilon keeps constant channels from producing an infinite scale.
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        """Apply ``scale * (input + loc)``, or its inverse when ``reverse``.

        Returns ``h`` with the same shape as ``input``. If ``self.logdet`` is
        True and ``reverse`` is False, returns ``(h, logdet)`` where ``logdet``
        is a 1-D tensor with one (identical) entry per batch element.
        """
        if reverse:
            return self.reverse(input)
        input, squeeze = self._to_4d(input)
        _, _, height, width = input.shape
        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)
        h = self.scale * (input + self.loc)
        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        if self.logdet:
            # log|det J| = H * W * sum_c log|scale_c|, the same for every sample.
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet
        return h

    def reverse(self, output):
        """Invert the forward map: ``output / scale - loc``.

        Raises:
            RuntimeError: if called in training mode before initialization
                while ``allow_reverse_init`` is False.
        """
        # Lift 2-D input first: initialize() permutes four dims, so it must
        # see the 4-D view (forward() does the same ordering).
        output, squeeze = self._to_4d(output)
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            self.initialize(output)
            self.initialized.fill_(1)
        h = output / self.scale - self.loc
        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders.

    Subclasses must override ``encode``; the base implementation always
    raises ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        # Abstract hook: concrete encoders provide the implementation.
        raise NotImplementedError()
class Labelator(AbstractEncoder):
    """Class-conditional encoder for the Net2Net interface.

    Passes the class labels through with an extra axis inserted at dim 1.
    ``n_classes`` is stored but not used by this class.
    """

    def __init__(self, n_classes, quantize_interface=True):
        super().__init__()
        self.n_classes = n_classes
        self.quantize_interface = quantize_interface

    def encode(self, c):
        """Return labels with a new dim 1 (presumably ``(batch,) -> (batch, 1)``).

        With ``quantize_interface`` the result is wrapped as
        ``(c, None, [None, None, c.long()])``.
        """
        c = c[:, None]
        if not self.quantize_interface:
            return c
        return c, None, [None, None, c.long()]
class SOSProvider(AbstractEncoder):
    """Encoder for unconditional training that emits a constant SOS token.

    The input to ``encode`` is used only for its batch size and device.
    """

    def __init__(self, sos_token, quantize_interface=True):
        super().__init__()
        self.sos_token = sos_token
        self.quantize_interface = quantize_interface

    def encode(self, x):
        """Return a ``(batch, 1)`` int64 tensor filled with ``sos_token``.

        The tensor lives on ``x.device``. With ``quantize_interface`` it is
        wrapped as ``(c, None, [None, None, c])``; otherwise returned bare.
        """
        # Build directly as int64 on the target device: a float32 round-trip
        # is inexact for token ids above 2**24, and allocating on CPU first
        # costs an extra host->device copy. int() truncates toward zero, the
        # same as the previous .long() cast.
        c = torch.full((x.shape[0], 1), int(self.sos_token),
                       dtype=torch.long, device=x.device)
        if self.quantize_interface:
            return c, None, [None, None, c]
        return c
|