"""Training losses: feature matching, LSGAN discriminator/generator, KL, and MLE."""
  1. import math
  2. import torch
  3. from torch.nn import functional as F
  4. def feature_loss(fmap_r, fmap_g):
  5. loss = 0
  6. for dr, dg in zip(fmap_r, fmap_g):
  7. for rl, gl in zip(dr, dg):
  8. rl = rl.float().detach()
  9. gl = gl.float()
  10. loss += torch.mean(torch.abs(rl - gl))
  11. return loss * 2
  12. def discriminator_loss(disc_real_outputs, disc_generated_outputs):
  13. loss = 0
  14. r_losses = []
  15. g_losses = []
  16. for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
  17. dr = dr.float()
  18. dg = dg.float()
  19. r_loss = torch.mean((1 - dr) ** 2)
  20. g_loss = torch.mean(dg**2)
  21. loss += r_loss + g_loss
  22. r_losses.append(r_loss.item())
  23. g_losses.append(g_loss.item())
  24. return loss, r_losses, g_losses
  25. def generator_loss(disc_outputs):
  26. loss = 0
  27. gen_losses = []
  28. for dg in disc_outputs:
  29. dg = dg.float()
  30. l = torch.mean((1 - dg) ** 2)
  31. gen_losses.append(l)
  32. loss += l
  33. return loss, gen_losses
  34. def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
  35. """
  36. z_p, logs_q: [b, h, t_t]
  37. m_p, logs_p: [b, h, t_t]
  38. """
  39. z_p = z_p.float()
  40. logs_q = logs_q.float()
  41. m_p = m_p.float()
  42. logs_p = logs_p.float()
  43. z_mask = z_mask.float()
  44. kl = logs_p - logs_q - 0.5
  45. kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
  46. kl = torch.sum(kl * z_mask)
  47. l = kl / torch.sum(z_mask)
  48. return l
  49. def mle_loss(z, m, logs, logdet, mask):
  50. l = torch.sum(logs) + 0.5 * torch.sum(
  51. torch.exp(-2 * logs) * ((z - m) ** 2)
  52. ) # neg normal likelihood w/o the constant term
  53. l = l - torch.sum(logdet) # log jacobian determinant
  54. l = l / torch.sum(
  55. torch.ones_like(z) * mask
  56. ) # averaging across batch, channel and time axes
  57. l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
  58. return l