plot.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import numpy as np
  2. def split_title_line(title_text, max_words=5):
  3. """
  4. A function that splits any string based on specific character
  5. (returning it with the string), with maximum number of words on it
  6. """
  7. seq = title_text.split()
  8. return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)])
  9. def plot_alignment(alignment, path, title=None, split_title=False, max_len=None):
  10. import matplotlib
  11. matplotlib.use("Agg")
  12. import matplotlib.pyplot as plt
  13. if max_len is not None:
  14. alignment = alignment[:, :max_len]
  15. fig = plt.figure(figsize=(8, 6))
  16. ax = fig.add_subplot(111)
  17. im = ax.imshow(
  18. alignment,
  19. aspect="auto",
  20. origin="lower",
  21. interpolation="none")
  22. fig.colorbar(im, ax=ax)
  23. xlabel = "Decoder timestep"
  24. if split_title:
  25. title = split_title_line(title)
  26. plt.xlabel(xlabel)
  27. plt.title(title)
  28. plt.ylabel("Encoder timestep")
  29. plt.tight_layout()
  30. plt.savefig(path, format="png")
  31. plt.close()
  32. def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
  33. import matplotlib
  34. matplotlib.use("Agg")
  35. import matplotlib.pyplot as plt
  36. if max_len is not None:
  37. target_spectrogram = target_spectrogram[:max_len]
  38. pred_spectrogram = pred_spectrogram[:max_len]
  39. if split_title:
  40. title = split_title_line(title)
  41. fig = plt.figure(figsize=(10, 8))
  42. # Set common labels
  43. fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)
  44. #target spectrogram subplot
  45. if target_spectrogram is not None:
  46. ax1 = fig.add_subplot(311)
  47. ax2 = fig.add_subplot(312)
  48. if auto_aspect:
  49. im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
  50. else:
  51. im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
  52. ax1.set_title("Target Mel-Spectrogram")
  53. fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
  54. ax2.set_title("Predicted Mel-Spectrogram")
  55. else:
  56. ax2 = fig.add_subplot(211)
  57. if auto_aspect:
  58. im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
  59. else:
  60. im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
  61. fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
  62. plt.tight_layout()
  63. plt.savefig(path, format="png")
  64. plt.close()