2
0

util.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import torch
  2. import torch.nn as nn
  def count_params(model):
      """Count the scalar elements held by a module's parameters.

      Args:
          model: Any object exposing ``parameters()`` (typically an
              ``nn.Module``).

      Returns:
          The sum of ``numel()`` over every tensor yielded by
          ``model.parameters()``, counted whether or not it requires grad.
      """
      n_elements = 0
      for weight in model.parameters():
          n_elements += weight.numel()
      return n_elements
  class ActNorm(nn.Module):
      """Per-channel affine normalization with data-dependent initialization.

      Computes ``h = scale * (x + loc)`` with one ``loc``/``scale`` value per
      channel (dim 1). While the ``initialized`` flag is 0, the first call made
      in training mode sets ``loc = -mean`` and ``scale = 1 / (std + 1e-6)``
      from that batch's per-channel statistics. The output of that batch
      therefore has approximately zero mean and unit standard deviation per
      channel.

      Inputs may be 4-D ``(N, C, H, W)`` or 2-D ``(N, C)``. 2-D inputs are
      treated as ``(N, C, 1, 1)`` internally and returned in 2-D form.

      Args:
          num_features: Number of channels ``C``.
          logdet: If True, ``forward`` returns ``(h, logdet)``, where
              ``logdet`` has one entry per batch element.
          affine: Must be True; only the affine variant is implemented.
          allow_reverse_init: If True, the data-dependent initialization may be
              triggered by ``reverse`` (using the values passed to it) instead
              of ``forward``.
      """

      def __init__(self, num_features, logdet=False, affine=True,
                   allow_reverse_init=False):
          assert affine
          super().__init__()
          self.logdet = logdet
          self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
          self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
          self.allow_reverse_init = allow_reverse_init
          # Registered as a buffer (not a plain attribute) so the flag is part
          # of state_dict and follows the module across devices.
          self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

      def initialize(self, input):
          """Set ``loc`` and ``scale`` from per-channel statistics of ``input``.

          Statistics are taken over every axis except dim 1 (the channel
          axis). If there is only one value per channel, an unbiased std is
          undefined (torch returns NaN). In that case ``scale`` is set from a
          std of 1 instead of being poisoned with NaN.
          """
          with torch.no_grad():
              num_channels = input.shape[1]
              # (N, C, H, W) -> (C, N*H*W): one row of samples per channel.
              flatten = input.transpose(0, 1).reshape(num_channels, -1)
              mean = flatten.mean(1)
              if flatten.shape[1] > 1:
                  std = flatten.std(1)
              else:
                  std = torch.ones_like(mean)
              # Reshape the (C,) statistics to the (1, C, 1, 1) parameter shape.
              self.loc.data.copy_(-mean.view(1, -1, 1, 1))
              self.scale.data.copy_(1 / (std.view(1, -1, 1, 1) + 1e-6))

      def forward(self, input, reverse=False):
          """Apply the normalization (or its inverse when ``reverse=True``).

          Returns:
              ``h`` with the same shape as ``input``. If ``self.logdet`` is
              True (and ``reverse`` is False), returns ``(h, logdet)``, where
              ``logdet = H * W * sum(log|scale|)`` repeated for each batch
              element.
          """
          if reverse:
              return self.reverse(input)
          squeeze = input.dim() == 2
          if squeeze:
              input = input[:, :, None, None]
          _, _, height, width = input.shape

          if self.training and self.initialized.item() == 0:
              self.initialize(input)
              self.initialized.fill_(1)

          h = self.scale * (input + self.loc)
          if squeeze:
              h = h.squeeze(-1).squeeze(-1)

          if self.logdet:
              log_abs = torch.log(torch.abs(self.scale))
              logdet = height * width * torch.sum(log_abs)
              logdet = logdet * torch.ones(input.shape[0]).to(input)
              return h, logdet
          return h

      def reverse(self, output):
          """Invert the transform: ``x = output / scale - loc``.

          Raises:
              RuntimeError: if called in training mode before initialization
                  while ``allow_reverse_init`` is False.
          """
          # Promote 2-D input to 4-D *before* any initialization, as
          # ``forward`` does; ``initialize`` needs the channel axis at dim 1
          # with the sample axes around it.
          squeeze = output.dim() == 2
          if squeeze:
              output = output[:, :, None, None]

          if self.training and self.initialized.item() == 0:
              if not self.allow_reverse_init:
                  raise RuntimeError(
                      "Initializing ActNorm in reverse direction is "
                      "disabled by default. Use allow_reverse_init=True to enable."
                  )
              self.initialize(output)
              self.initialized.fill_(1)

          h = output / self.scale - self.loc
          if squeeze:
              h = h.squeeze(-1).squeeze(-1)
          return h
  class AbstractEncoder(nn.Module):
      """Base class for conditioning encoders.

      Adds no parameters or state of its own; construction is delegated
      entirely to ``nn.Module.__init__``. Subclasses are expected to override
      :meth:`encode`.
      """

      def encode(self, *args, **kwargs):
          """Encode the inputs; the base class only raises NotImplementedError."""
          raise NotImplementedError
  class Labelator(AbstractEncoder):
      """Net2Net interface for a class-conditional model.

      Turns a batch of class labels into a conditioning tensor by adding a
      trailing axis at dim 1.

      Args:
          n_classes: Number of classes. It is stored on the instance but not
              read by this class.
          quantize_interface: If True, ``encode`` returns a 3-tuple (see
              ``encode``); otherwise it returns the conditioning tensor alone.
      """

      def __init__(self, n_classes, quantize_interface=True):
          super().__init__()
          self.quantize_interface = quantize_interface
          self.n_classes = n_classes

      def encode(self, c):
          """Insert an axis at dim 1 of ``c`` and package the result.

          Returns:
              ``c[:, None]`` when ``quantize_interface`` is False. Otherwise
              the tuple ``(c[:, None], None, [None, None, c[:, None].long()])``.
              NOTE(review): this presumably mirrors the return signature of a
              quantizing encoder; confirm with callers.
          """
          labels = c[:, None]
          if not self.quantize_interface:
              return labels
          return labels, None, [None, None, labels.long()]
  class SOSProvider(AbstractEncoder):
      """Conditioning encoder for unconditional training.

      Ignores the contents of its input and emits the constant
      ``sos_token`` once per batch element.

      Args:
          sos_token: Token id to emit. It is converted with ``int()``, which
              truncates float values toward zero as the previous ``.long()``
              cast did.
          quantize_interface: If True, ``encode`` returns the tuple
              ``(c, None, [None, None, c])``; otherwise it returns ``c`` alone.
      """

      def __init__(self, sos_token, quantize_interface=True):
          super().__init__()
          self.sos_token = sos_token
          self.quantize_interface = quantize_interface

      def encode(self, x):
          """Return a ``(x.shape[0], 1)`` int64 tensor filled with ``sos_token``.

          Only ``x.shape[0]`` and ``x.device`` are read from ``x``.
          """
          # Build the tensor directly as int64 on the target device. The old
          # float32 round-trip (ones * token -> long) silently rounded token
          # ids above 2**24, and it allocated on CPU before copying over.
          c = torch.full((x.shape[0], 1), int(self.sos_token),
                         dtype=torch.long, device=x.device)
          if self.quantize_interface:
              return c, None, [None, None, c]
          return c