distribution.py
import numpy as np
import torch
import torch.nn.functional as F


def log_sum_exp(x):
    """ numerically stable log_sum_exp implementation that prevents overflow """
    # TF ordering
    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis)
    m2, _ = torch.max(x, dim=axis, keepdim=True)
    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
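
# Illustrative sketch (not part of the original file): log_sum_exp reduces over
# the last axis, so per-mixture log-probabilities of shape (B, T, nr_mix)
# collapse to (B, T). The tensor sizes below are assumptions for the example only.
def _example_log_sum_exp():
    x = torch.randn(2, 5, 10)   # e.g. (B, T, nr_mix) mixture log-probabilities
    out = log_sum_exp(x)        # -> shape (2, 5)
    # matches torch.logsumexp over the last dimension
    assert torch.allclose(out, torch.logsumexp(x, dim=-1), atol=1e-6)
    return out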
# Adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
                                  log_scale_min=None, reduce=True):
    """Discretized mixture of logistic distributions loss.

    Note that the target ``y`` is assumed to be scaled to [-1, 1].

    Args:
        y_hat (Tensor): Predicted distribution parameters, shape (B, T, 3 * num_mixtures).
        y (Tensor): Target signal, shape (B, T, 1).
        num_classes (int): Number of quantization levels (65536 for 16-bit audio).
        log_scale_min (float): Minimum value for the predicted log scales.
        reduce (bool): If True, return the mean negative log-likelihood as a scalar.

    Returns:
        Tensor: Scalar loss if ``reduce``, otherwise per-step loss of shape (B, T, 1).
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    y_hat = y_hat.permute(0, 2, 1)
    assert y_hat.dim() == 3
    assert y_hat.size(1) % 3 == 0
    nr_mix = y_hat.size(1) // 3

    # (B x T x C)
    y_hat = y_hat.transpose(1, 2)

    # unpack parameters. (B, T, num_mixtures) x 3
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix:2 * nr_mix]
    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)

    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)

    centered_y = y - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
    cdf_min = torch.sigmoid(min_in)

    # log probability for edge case of 0 (before scaling)
    # equivalent: torch.log(F.sigmoid(plus_in))
    log_cdf_plus = plus_in - F.softplus(plus_in)

    # log probability for edge case of 255 (before scaling)
    # equivalent: (1 - F.sigmoid(min_in)).log()
    log_one_minus_cdf_min = -F.softplus(min_in)

    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min

    mid_in = inv_stdv * centered_y
    # log probability in the center of the bin, to be used in extreme cases
    # (not actually used in our code)
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # tf equivalent
    """
    log_probs = tf.where(x < -0.999, log_cdf_plus,
                         tf.where(x > 0.999, log_one_minus_cdf_min,
                                  tf.where(cdf_delta > 1e-5,
                                           tf.log(tf.maximum(cdf_delta, 1e-12)),
                                           log_pdf_mid - np.log(127.5))))
    """
    # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
    # for num_classes=65536 case? 1e-7? not sure..
    inner_inner_cond = (cdf_delta > 1e-5).float()

    inner_inner_out = inner_inner_cond * \
        torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
        (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
    inner_cond = (y > 0.999).float()
    inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
    cond = (y < -0.999).float()
    log_probs = cond * log_cdf_plus + (1. - cond) * inner_out

    log_probs = log_probs + F.log_softmax(logit_probs, -1)

    if reduce:
        return -torch.mean(log_sum_exp(log_probs))
    else:
        return -log_sum_exp(log_probs).unsqueeze(-1)
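
# Illustrative sketch (not part of the original file): how this loss is
# typically called. The batch/time sizes and the random target are assumptions;
# the function itself expects y_hat as (B, T, 3 * nr_mix) and y as (B, T, 1)
# scaled to [-1, 1].
def _example_discretized_mix_logistic_loss():
    B, T, nr_mix = 2, 100, 10
    y_hat = torch.randn(B, T, 3 * nr_mix)   # raw network output
    y = torch.rand(B, T, 1) * 2. - 1.       # fake target in [-1, 1]
    loss = discretized_mix_logistic_loss(y_hat, y, num_classes=65536, reduce=True)
    return loss                             # scalar negative log-likelihood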
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """
    Sample from discretized mixture of logistic distributions

    Args:
        y (Tensor): B x C x T
        log_scale_min (float): Log scale minimum value

    Returns:
        Tensor: sample in range of [-1, 1], shape (B, T).
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    assert y.size(1) % 3 == 0
    nr_mix = y.size(1) // 3

    # B x T x C
    y = y.transpose(1, 2)
    logit_probs = y[:, :, :nr_mix]

    # sample mixture indicator from softmax (Gumbel-max trick)
    temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
    temp = logit_probs.data - torch.log(- torch.log(temp))
    _, argmax = temp.max(dim=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = to_one_hot(argmax, nr_mix)
    # select logistic parameters
    means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
    log_scales = torch.clamp(torch.sum(
        y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)

    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))

    x = torch.clamp(torch.clamp(x, min=-1.), max=1.)

    return x
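
# Illustrative sketch (not part of the original file): sampling expects the
# mixture parameters channel-first, i.e. (B, 3 * nr_mix, T), and returns a
# waveform-like tensor of shape (B, T) in [-1, 1]. Sizes are assumptions.
def _example_sample_from_discretized_mix_logistic():
    B, T, nr_mix = 2, 100, 10
    y = torch.randn(B, 3 * nr_mix, T)            # (B, C, T) mixture parameters
    x = sample_from_discretized_mix_logistic(y)  # -> (B, T), clipped to [-1, 1]
    return x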
def to_one_hot(tensor, n, fill_with=1.):
    # we perform one-hot encoding with respect to the last axis
    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
    if tensor.is_cuda:
        one_hot = one_hot.cuda()
    one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
    return one_hot
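
# Illustrative sketch (not part of the original file): to_one_hot appends a new
# last axis of size n and scatters fill_with at each index. The index values
# below are assumptions for the example only.
def _example_to_one_hot():
    idx = torch.tensor([[0, 2], [1, 1]])   # shape (2, 2), values < n
    one_hot = to_one_hot(idx, n=3)         # -> shape (2, 2, 3)
    # one_hot[0, 1] == [0., 0., 1.] because idx[0, 1] == 2
    return one_hot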