# model.py

from encoder.params_model import *
from encoder.params_data import *
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from torch.nn.utils import clip_grad_norm_
from scipy.optimize import brentq
from torch import nn
import numpy as np
import torch

class SpeakerEncoder(nn.Module):
    def __init__(self, device, loss_device):
        super().__init__()
        self.loss_device = loss_device

        # Network definition
        self.lstm = nn.LSTM(input_size=mel_n_channels,
                            hidden_size=model_hidden_size,
                            num_layers=model_num_layers,
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values)
        self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
        self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        # Gradient scale
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
        # and the final cell state.
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it
        embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)

        return embeds

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according to section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        """
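        # GE2E (Wan et al., 2018), section 2.1: each entry is a scaled cosine similarity
        # S_{ji,k} = w * cos(e_{ji}, c_k) + b between utterance embedding e_{ji} and speaker
        # centroid c_k; when k == j, the centroid excludes e_{ji} itself, which is what the
        # "exclusive" centroids below implement.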
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)

        # Exclusive centroids (1 per utterance)
        centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)

        # Similarity matrix. The cosine similarity of already L2-normed vectors is simply the dot
        # product of these vectors (which is just an element-wise multiplication reduced by a sum).
        # We vectorize the computation for efficiency.
        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                 speakers_per_batch).to(self.loss_device)
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
        for j in range(speakers_per_batch):
            mask = np.where(mask_matrix[j])[0]
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        ## Even more vectorized version (slower maybe because of transpose)
        # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
        #                           ).to(self.loss_device)
        # eye = np.eye(speakers_per_batch, dtype=int)
        # mask = np.where(1 - eye)
        # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
        # mask = np.where(eye)
        # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
        # sim_matrix2 = sim_matrix2.transpose(1, 2)

        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix

    def loss(self, embeds):
        """
        Computes the softmax loss according to section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
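        # GE2E softmax loss: for utterance j,i the loss is -S_{ji,j} + log(sum_k exp(S_{ji,k})),
        # which is exactly cross-entropy over each similarity row with the speaker index as target.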
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
                                         speakers_per_batch))
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated)
        with torch.no_grad():
            inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
            labels = np.array([inv_argmax(i) for i in ground_truth])
            preds = sim_matrix.detach().cpu().numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

        return loss, eer
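

# --------------------------------------------------------------------------- #
# Minimal usage sketch (not part of the original module): instantiates the
# encoder on CPU and runs a random batch through forward() and loss(). The
# batch sizes below are illustrative assumptions; mel_n_channels comes from
# the star-imported encoder.params_data, and the flat batch of embeddings is
# reshaped to (speakers_per_batch, utterances_per_speaker, embedding_size)
# before calling loss(), as its docstring requires.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    device = torch.device("cpu")
    model = SpeakerEncoder(device, device)

    speakers_per_batch, utterances_per_speaker, n_frames = 4, 5, 160
    # forward() expects a flat batch of spectrograms: (batch_size, n_frames, n_channels)
    frames = torch.rand(speakers_per_batch * utterances_per_speaker, n_frames, mel_n_channels)
    embeds = model(frames.to(device))

    # loss() expects the embeddings grouped back by speaker
    embeds = embeds.view(speakers_per_batch, utterances_per_speaker, -1)
    loss, eer = model.loss(embeds)
    print("loss:", loss.item(), "EER:", eer)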