import sys
import traceback
from pathlib import Path
from time import perf_counter as timer

import numpy as np
import torch

from encoder import inference as encoder
from synthesizer.inference import Synthesizer
from toolbox.ui import UI
from toolbox.utterance import Utterance
from vocoder import inference as vocoder

# Use this directory structure for your datasets, or modify it to fit your needs
recognized_datasets = [
    "LibriSpeech/dev-clean",
    "LibriSpeech/dev-other",
    "LibriSpeech/test-clean",
    "LibriSpeech/test-other",
    "LibriSpeech/train-clean-100",
    "LibriSpeech/train-clean-360",
    "LibriSpeech/train-other-500",
    "LibriTTS/dev-clean",
    "LibriTTS/dev-other",
    "LibriTTS/test-clean",
    "LibriTTS/test-other",
    "LibriTTS/train-clean-100",
    "LibriTTS/train-clean-360",
    "LibriTTS/train-other-500",
    "LJSpeech-1.1",
    "VoxCeleb1/wav",
    "VoxCeleb1/test_wav",
    "VoxCeleb2/dev/aac",
    "VoxCeleb2/test/aac",
    "VCTK-Corpus/wav48",
]
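
# Illustrative on-disk layout for one of the recognized datasets above (an example for
# orientation only, not a structure enforced by this module; your corpora may differ):
#   <datasets_root>/LibriSpeech/dev-clean/<speaker_id>/<chapter_id>/<utterance_id>.flac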

# Maximum number of generated wavs to keep in memory
MAX_WAVS = 15


class Toolbox:
    def __init__(self, datasets_root: Path, models_dir: Path, seed: int = None):
        sys.excepthook = self.excepthook
        self.datasets_root = datasets_root
        self.utterances = set()
        self.current_generated = (None, None, None, None)  # speaker_name, spec, breaks, wav

        self.synthesizer = None  # type: Synthesizer
        self.current_wav = None
        self.waves_list = []
        self.waves_count = 0
        self.waves_namelist = []

        # Check for webrtcvad (enables removal of silences in vocoder output)
        try:
            import webrtcvad
            self.trim_silences = True
        except ImportError:
            self.trim_silences = False

        # Initialize the events and the interface
        self.ui = UI()
        self.reset_ui(models_dir, seed)
        self.setup_events()
        self.ui.start()

    def excepthook(self, exc_type, exc_value, exc_tb):
        traceback.print_exception(exc_type, exc_value, exc_tb)
        self.ui.log("Exception: %s" % exc_value)

    def setup_events(self):
        # Dataset, speaker and utterance selection
        self.ui.browser_load_button.clicked.connect(lambda: self.load_from_browser())
        random_func = lambda level: lambda: self.ui.populate_browser(self.datasets_root,
                                                                     recognized_datasets,
                                                                     level)
        self.ui.random_dataset_button.clicked.connect(random_func(0))
        self.ui.random_speaker_button.clicked.connect(random_func(1))
        self.ui.random_utterance_button.clicked.connect(random_func(2))
        self.ui.dataset_box.currentIndexChanged.connect(random_func(1))
        self.ui.speaker_box.currentIndexChanged.connect(random_func(2))

        # Model selection
        self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
        def func():
            self.synthesizer = None
        self.ui.synthesizer_box.currentIndexChanged.connect(func)
        self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)

        # Utterance selection
        func = lambda: self.load_from_browser(self.ui.browse_file())
        self.ui.browser_browse_button.clicked.connect(func)
        func = lambda: self.ui.draw_utterance(self.ui.selected_utterance, "current")
        self.ui.utterance_history.currentIndexChanged.connect(func)
        func = lambda: self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate)
        self.ui.play_button.clicked.connect(func)
        self.ui.stop_button.clicked.connect(self.ui.stop)
        self.ui.record_button.clicked.connect(self.record)

        # Audio
        self.ui.setup_audio_devices(Synthesizer.sample_rate)

        # Wav playback & save
        func = lambda: self.replay_last_wav()
        self.ui.replay_wav_button.clicked.connect(func)
        func = lambda: self.export_current_wave()
        self.ui.export_wav_button.clicked.connect(func)
        self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)

        # Generation
        func = lambda: self.synthesize() or self.vocode()
        self.ui.generate_button.clicked.connect(func)
        self.ui.synthesize_button.clicked.connect(self.synthesize)
        self.ui.vocode_button.clicked.connect(self.vocode)
        self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)

        # UMAP legend
        self.ui.clear_button.clicked.connect(self.clear_utterances)

    def set_current_wav(self, index):
        self.current_wav = self.waves_list[index]

    def export_current_wave(self):
        self.ui.save_audio_file(self.current_wav, Synthesizer.sample_rate)

    def replay_last_wav(self):
        self.ui.play(self.current_wav, Synthesizer.sample_rate)

    def reset_ui(self, models_dir: Path, seed: int = None):
        self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
        self.ui.populate_models(models_dir)
        self.ui.populate_gen_options(seed, self.trim_silences)

    def load_from_browser(self, fpath=None):
        if fpath is None:
            fpath = Path(self.datasets_root,
                         self.ui.current_dataset_name,
                         self.ui.current_speaker_name,
                         self.ui.current_utterance_name)
            name = str(fpath.relative_to(self.datasets_root))
            speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name

            # Select the next utterance
            if self.ui.auto_next_checkbox.isChecked():
                self.ui.browser_select_next()
        elif fpath == "":
            return
        else:
            name = fpath.name
            speaker_name = fpath.parent.name

        # Get the wav from the disk. We take the wav with the vocoder/synthesizer format for
        # playback, so as to have a fair comparison with the generated audio
        wav = Synthesizer.load_preprocess_wav(fpath)
        self.ui.log("Loaded %s" % name)

        self.add_real_utterance(wav, name, speaker_name)

    def record(self):
        wav = self.ui.record_one(encoder.sampling_rate, 5)
        if wav is None:
            return
        self.ui.play(wav, encoder.sampling_rate)

        speaker_name = "user01"
        name = speaker_name + "_rec_%05d" % np.random.randint(100000)
        self.add_real_utterance(wav, name, speaker_name)

    def add_real_utterance(self, wav, name, speaker_name):
        # Compute the mel spectrogram
        spec = Synthesizer.make_spectrogram(wav)
        self.ui.draw_spec(spec, "current")

        # Compute the embedding
        if not encoder.is_loaded():
            self.init_encoder()
        encoder_wav = encoder.preprocess_wav(wav)
        embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # Add the utterance
        utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
        self.utterances.add(utterance)
        self.ui.register_utterance(utterance)

        # Plot it
        self.ui.draw_embed(embed, name, "current")
        self.ui.draw_umap_projections(self.utterances)

    def clear_utterances(self):
        self.utterances.clear()
        self.ui.draw_umap_projections(self.utterances)

    def synthesize(self):
        self.ui.log("Generating the mel spectrogram...")
        self.ui.set_loading(1)

        # Update the synthesizer random seed
        if self.ui.random_seed_checkbox.isChecked():
            seed = int(self.ui.seed_textbox.text())
            self.ui.populate_gen_options(seed, self.trim_silences)
        else:
            seed = None

        if seed is not None:
            torch.manual_seed(seed)

        # Synthesize the spectrogram
        if self.synthesizer is None or seed is not None:
            self.init_synthesizer()

        texts = self.ui.text_prompt.toPlainText().split("\n")
        embed = self.ui.selected_utterance.embed
        embeds = [embed] * len(texts)
        specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
        breaks = [spec.shape[1] for spec in specs]
        spec = np.concatenate(specs, axis=1)

        self.ui.draw_spec(spec, "generated")
        self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None)
        self.ui.set_loading(0)

    def vocode(self):
        speaker_name, spec, breaks, _ = self.current_generated
        assert spec is not None

        # Initialize the vocoder model and make it deterministic, if user provides a seed
        if self.ui.random_seed_checkbox.isChecked():
            seed = int(self.ui.seed_textbox.text())
            self.ui.populate_gen_options(seed, self.trim_silences)
        else:
            seed = None

        if seed is not None:
            torch.manual_seed(seed)

        # Synthesize the waveform
        if not vocoder.is_loaded() or seed is not None:
            self.init_vocoder()

        def vocoder_progress(i, seq_len, b_size, gen_rate):
            real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
            line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \
                   % (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor)
            self.ui.log(line, "overwrite")
            self.ui.set_loading(i, seq_len)

        if self.ui.current_vocoder_fpath is not None:
            self.ui.log("")
            wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
        else:
            self.ui.log("Waveform generation with Griffin-Lim... ")
            wav = Synthesizer.griffin_lim(spec)
        self.ui.set_loading(0)
        self.ui.log(" Done!", "append")

        # Add breaks
        b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
        b_starts = np.concatenate(([0], b_ends[:-1]))
        wavs = [wav[start:end] for start, end in zip(b_starts, b_ends)]
        breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
        wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])

        # Trim excessive silences
        if self.ui.trim_silences_checkbox.isChecked():
            wav = encoder.preprocess_wav(wav)

        # Play it
        wav = wav / np.abs(wav).max() * 0.97
        self.ui.play(wav, Synthesizer.sample_rate)

        # Name it (history displayed in combobox)
        # TODO better naming for the combobox items?
        wav_name = str(self.waves_count + 1)

        # Update waves combobox
        self.waves_count += 1
        if self.waves_count > MAX_WAVS:
            self.waves_list.pop()
            self.waves_namelist.pop()

        self.waves_list.insert(0, wav)
        self.waves_namelist.insert(0, wav_name)

        self.ui.waves_cb.disconnect()
        self.ui.waves_cb_model.setStringList(self.waves_namelist)
        self.ui.waves_cb.setCurrentIndex(0)
        self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)

        # Update current wav
        self.set_current_wav(0)

        # Enable replay and save buttons
        self.ui.replay_wav_button.setDisabled(False)
        self.ui.export_wav_button.setDisabled(False)

        # Compute the embedding
        # TODO: this is problematic with different sampling rates, gotta fix it
        if not encoder.is_loaded():
            self.init_encoder()
        encoder_wav = encoder.preprocess_wav(wav)
        embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # Add the utterance
        name = speaker_name + "_gen_%05d" % np.random.randint(100000)
        utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True)
        self.utterances.add(utterance)

        # Plot it
        self.ui.draw_embed(embed, name, "generated")
        self.ui.draw_umap_projections(self.utterances)

    def init_encoder(self):
        model_fpath = self.ui.current_encoder_fpath

        self.ui.log("Loading the encoder %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        encoder.load_model(model_fpath)
        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        self.ui.set_loading(0)

    def init_synthesizer(self):
        model_fpath = self.ui.current_synthesizer_fpath

        self.ui.log("Loading the synthesizer %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        self.synthesizer = Synthesizer(model_fpath)
        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        self.ui.set_loading(0)

    def init_vocoder(self):
        model_fpath = self.ui.current_vocoder_fpath
        # Case of Griffin-Lim (no vocoder model selected)
        if model_fpath is None:
            return

        self.ui.log("Loading the vocoder %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        vocoder.load_model(model_fpath)
        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        self.ui.set_loading(0)

    def update_seed_textbox(self):
        self.ui.update_seed_textbox()
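

# Minimal launch sketch (an illustration, not the repository's official entry point;
# the argument names mirror the Toolbox constructor above and the default paths are
# placeholders that may not match your checkout):
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Launch the voice-cloning toolbox UI")
    parser.add_argument("--datasets_root", type=Path, default=Path("datasets"),
                        help="Directory containing the recognized datasets listed above")
    parser.add_argument("--models_dir", type=Path, default=Path("saved_models"),
                        help="Directory containing encoder/synthesizer/vocoder checkpoints")
    parser.add_argument("--seed", type=int, default=None,
                        help="Optional seed for deterministic generation")
    args = parser.parse_args()

    # Constructing the Toolbox builds the UI, wires the events, and blocks in the UI loop
    Toolbox(args.datasets_root, args.models_dir, args.seed)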