display.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import time
  2. import numpy as np
  3. import sys
  4. def progbar(i, n, size=16):
  5. done = (i * size) // n
  6. bar = ''
  7. for i in range(size):
  8. bar += '█' if i <= done else '░'
  9. return bar
  10. def stream(message) :
  11. try:
  12. sys.stdout.write("\r{%s}" % message)
  13. except:
  14. #Remove non-ASCII characters from message
  15. message = ''.join(i for i in message if ord(i)<128)
  16. sys.stdout.write("\r{%s}" % message)
  17. def simple_table(item_tuples) :
  18. border_pattern = '+---------------------------------------'
  19. whitespace = ' '
  20. headings, cells, = [], []
  21. for item in item_tuples :
  22. heading, cell = str(item[0]), str(item[1])
  23. pad_head = True if len(heading) < len(cell) else False
  24. pad = abs(len(heading) - len(cell))
  25. pad = whitespace[:pad]
  26. pad_left = pad[:len(pad)//2]
  27. pad_right = pad[len(pad)//2:]
  28. if pad_head :
  29. heading = pad_left + heading + pad_right
  30. else :
  31. cell = pad_left + cell + pad_right
  32. headings += [heading]
  33. cells += [cell]
  34. border, head, body = '', '', ''
  35. for i in range(len(item_tuples)) :
  36. temp_head = f'| {headings[i]} '
  37. temp_body = f'| {cells[i]} '
  38. border += border_pattern[:len(temp_head)]
  39. head += temp_head
  40. body += temp_body
  41. if i == len(item_tuples) - 1 :
  42. head += '|'
  43. body += '|'
  44. border += '+'
  45. print(border)
  46. print(head)
  47. print(border)
  48. print(body)
  49. print(border)
  50. print(' ')
  51. def time_since(started) :
  52. elapsed = time.time() - started
  53. m = int(elapsed // 60)
  54. s = int(elapsed % 60)
  55. if m >= 60 :
  56. h = int(m // 60)
  57. m = m % 60
  58. return f'{h}h {m}m {s}s'
  59. else :
  60. return f'{m}m {s}s'
  61. def save_attention(attn, path):
  62. import matplotlib.pyplot as plt
  63. fig = plt.figure(figsize=(12, 6))
  64. plt.imshow(attn.T, interpolation='nearest', aspect='auto')
  65. fig.savefig(f'{path}.png', bbox_inches='tight')
  66. plt.close(fig)
  67. def save_spectrogram(M, path, length=None):
  68. import matplotlib.pyplot as plt
  69. M = np.flip(M, axis=0)
  70. if length : M = M[:, :length]
  71. fig = plt.figure(figsize=(12, 6))
  72. plt.imshow(M, interpolation='nearest', aspect='auto')
  73. fig.savefig(f'{path}.png', bbox_inches='tight')
  74. plt.close(fig)
  75. def plot(array):
  76. import matplotlib.pyplot as plt
  77. fig = plt.figure(figsize=(30, 5))
  78. ax = fig.add_subplot(111)
  79. ax.xaxis.label.set_color('grey')
  80. ax.yaxis.label.set_color('grey')
  81. ax.xaxis.label.set_fontsize(23)
  82. ax.yaxis.label.set_fontsize(23)
  83. ax.tick_params(axis='x', colors='grey', labelsize=23)
  84. ax.tick_params(axis='y', colors='grey', labelsize=23)
  85. plt.plot(array)
  86. def plot_spec(M):
  87. import matplotlib.pyplot as plt
  88. M = np.flip(M, axis=0)
  89. plt.figure(figsize=(18,4))
  90. plt.imshow(M, interpolation='nearest', aspect='auto')
  91. plt.show()