visualizations.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. from datetime import datetime
  2. from time import perf_counter as timer
  3. import numpy as np
  4. import umap
  5. import visdom
  6. from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
  7. colormap = np.array([
  8. [76, 255, 0],
  9. [0, 127, 70],
  10. [255, 0, 0],
  11. [255, 217, 38],
  12. [0, 135, 255],
  13. [165, 0, 165],
  14. [255, 167, 255],
  15. [0, 255, 255],
  16. [255, 96, 38],
  17. [142, 76, 0],
  18. [33, 0, 127],
  19. [0, 0, 0],
  20. [183, 183, 183],
  21. ], dtype=np.float) / 255
  22. class Visualizations:
  23. def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
  24. # Tracking data
  25. self.last_update_timestamp = timer()
  26. self.update_every = update_every
  27. self.step_times = []
  28. self.losses = []
  29. self.eers = []
  30. print("Updating the visualizations every %d steps." % update_every)
  31. # If visdom is disabled TODO: use a better paradigm for that
  32. self.disabled = disabled
  33. if self.disabled:
  34. return
  35. # Set the environment name
  36. now = str(datetime.now().strftime("%d-%m %Hh%M"))
  37. if env_name is None:
  38. self.env_name = now
  39. else:
  40. self.env_name = "%s (%s)" % (env_name, now)
  41. # Connect to visdom and open the corresponding window in the browser
  42. try:
  43. self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
  44. except ConnectionError:
  45. raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
  46. "start it.")
  47. # webbrowser.open("http://localhost:8097/env/" + self.env_name)
  48. # Create the windows
  49. self.loss_win = None
  50. self.eer_win = None
  51. # self.lr_win = None
  52. self.implementation_win = None
  53. self.projection_win = None
  54. self.implementation_string = ""
  55. def log_params(self):
  56. if self.disabled:
  57. return
  58. from encoder import params_data
  59. from encoder import params_model
  60. param_string = "<b>Model parameters</b>:<br>"
  61. for param_name in (p for p in dir(params_model) if not p.startswith("__")):
  62. value = getattr(params_model, param_name)
  63. param_string += "\t%s: %s<br>" % (param_name, value)
  64. param_string += "<b>Data parameters</b>:<br>"
  65. for param_name in (p for p in dir(params_data) if not p.startswith("__")):
  66. value = getattr(params_data, param_name)
  67. param_string += "\t%s: %s<br>" % (param_name, value)
  68. self.vis.text(param_string, opts={"title": "Parameters"})
  69. def log_dataset(self, dataset: SpeakerVerificationDataset):
  70. if self.disabled:
  71. return
  72. dataset_string = ""
  73. dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
  74. dataset_string += "\n" + dataset.get_logs()
  75. dataset_string = dataset_string.replace("\n", "<br>")
  76. self.vis.text(dataset_string, opts={"title": "Dataset"})
  77. def log_implementation(self, params):
  78. if self.disabled:
  79. return
  80. implementation_string = ""
  81. for param, value in params.items():
  82. implementation_string += "<b>%s</b>: %s\n" % (param, value)
  83. implementation_string = implementation_string.replace("\n", "<br>")
  84. self.implementation_string = implementation_string
  85. self.implementation_win = self.vis.text(
  86. implementation_string,
  87. opts={"title": "Training implementation"}
  88. )
  89. def update(self, loss, eer, step):
  90. # Update the tracking data
  91. now = timer()
  92. self.step_times.append(1000 * (now - self.last_update_timestamp))
  93. self.last_update_timestamp = now
  94. self.losses.append(loss)
  95. self.eers.append(eer)
  96. print(".", end="")
  97. # Update the plots every <update_every> steps
  98. if step % self.update_every != 0:
  99. return
  100. time_string = "Step time: mean: %5dms std: %5dms" % \
  101. (int(np.mean(self.step_times)), int(np.std(self.step_times)))
  102. print("\nStep %6d Loss: %.4f EER: %.4f %s" %
  103. (step, np.mean(self.losses), np.mean(self.eers), time_string))
  104. if not self.disabled:
  105. self.loss_win = self.vis.line(
  106. [np.mean(self.losses)],
  107. [step],
  108. win=self.loss_win,
  109. update="append" if self.loss_win else None,
  110. opts=dict(
  111. legend=["Avg. loss"],
  112. xlabel="Step",
  113. ylabel="Loss",
  114. title="Loss",
  115. )
  116. )
  117. self.eer_win = self.vis.line(
  118. [np.mean(self.eers)],
  119. [step],
  120. win=self.eer_win,
  121. update="append" if self.eer_win else None,
  122. opts=dict(
  123. legend=["Avg. EER"],
  124. xlabel="Step",
  125. ylabel="EER",
  126. title="Equal error rate"
  127. )
  128. )
  129. if self.implementation_win is not None:
  130. self.vis.text(
  131. self.implementation_string + ("<b>%s</b>" % time_string),
  132. win=self.implementation_win,
  133. opts={"title": "Training implementation"},
  134. )
  135. # Reset the tracking
  136. self.losses.clear()
  137. self.eers.clear()
  138. self.step_times.clear()
  139. def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10):
  140. import matplotlib.pyplot as plt
  141. max_speakers = min(max_speakers, len(colormap))
  142. embeds = embeds[:max_speakers * utterances_per_speaker]
  143. n_speakers = len(embeds) // utterances_per_speaker
  144. ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
  145. colors = [colormap[i] for i in ground_truth]
  146. reducer = umap.UMAP()
  147. projected = reducer.fit_transform(embeds)
  148. plt.scatter(projected[:, 0], projected[:, 1], c=colors)
  149. plt.gca().set_aspect("equal", "datalim")
  150. plt.title("UMAP projection (step %d)" % step)
  151. if not self.disabled:
  152. self.projection_win = self.vis.matplot(plt, win=self.projection_win)
  153. if out_fpath is not None:
  154. plt.savefig(out_fpath)
  155. plt.clf()
  156. def save(self):
  157. if not self.disabled:
  158. self.vis.save([self.env_name])