save_waveform_plot.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import io
  2. import matplotlib
  3. import matplotlib.figure as mpl_fig
  4. from matplotlib import pyplot as plt
  5. import numpy as np
  6. matplotlib.use("agg")
  7. def plot_waveform(audio_array: np.ndarray):
  8. fig = plt.figure(figsize=(10, 3))
  9. plt.style.use("dark_background")
  10. plt.plot(audio_array, color="orange")
  11. plt.axis("off")
  12. return fig
  13. def figure_to_image(fig: mpl_fig.Figure):
  14. with io.BytesIO() as buff:
  15. fig.savefig(buff, format="raw")
  16. buff.seek(0)
  17. data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
  18. w, h = fig.canvas.get_width_height()
  19. return data.reshape((int(h), int(w), -1))
  20. def plot_waveform_as_image(audio_array: np.ndarray):
  21. fig = plot_waveform(audio_array)
  22. plt.close()
  23. return figure_to_image(fig)
  24. def middleware_save_waveform_plot(audio_array: np.ndarray, filename_png: str):
  25. # fig = plt.figure(figsize=(10, 3))
  26. # plt.style.use("dark_background")
  27. # plt.plot(audio_array, color="orange")
  28. # plt.axis("off")
  29. # plt.savefig(filename_png)
  30. # plt.close()
  31. fig = plot_waveform(audio_array)
  32. plt.savefig(filename_png)
  33. plt.close()
  34. return figure_to_image(fig)
  35. if __name__ == "__main__":
  36. print("Testing save_waveform_plot.py")
  37. audio_array = np.random.rand(100)
  38. filename_png = "test.png"
  39. data = middleware_save_waveform_plot(audio_array, filename_png)
  40. print(data)
  41. print("Testing save_waveform_plot.py done")