train.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  from datetime import datetime
  from functools import partial
  from pathlib import Path
  import time

  import numpy as np
  import torch
  import torch.nn.functional as F
  from torch import optim
  from torch.utils.data import DataLoader

  from synthesizer import audio
  from synthesizer.models.tacotron import Tacotron
  from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
  from synthesizer.utils import ValueWindow, data_parallel_workaround
  from synthesizer.utils.plot import plot_spectrogram
  from synthesizer.utils.symbols import symbols
  from synthesizer.utils.text import sequence_to_text
  from vocoder.display import *
  def np_now(x: torch.Tensor):
      """Return *x* as a NumPy array, detached from autograd and on the CPU."""
      # .numpy() refuses tensors that require grad or live on a GPU, so strip
      # the autograd history and move to host memory first.
      host_tensor = x.detach().cpu()
      return host_tensor.numpy()
  def time_string():
      """Return the current local time as a ``YYYY-MM-DD HH:MM`` string."""
      now = datetime.now()
      return f"{now:%Y-%m-%d %H:%M}"
  def _prepare_run_dirs(models_dir: Path, run_id: str):
      """Create the run's output directories and return their paths.

      Returns ``(model_dir, plot_dir, wav_dir, mel_output_dir, meta_folder)``,
      all located under ``models_dir / run_id``.
      """
      models_dir.mkdir(exist_ok=True)
      model_dir = models_dir.joinpath(run_id)
      plot_dir = model_dir.joinpath("plots")
      wav_dir = model_dir.joinpath("wavs")
      mel_output_dir = model_dir.joinpath("mel-spectrograms")
      meta_folder = model_dir.joinpath("metas")
      # Non-recursive mkdir: model_dir has to be created before its children.
      for directory in (model_dir, plot_dir, wav_dir, mel_output_dir, meta_folder):
          directory.mkdir(exist_ok=True)
      return model_dir, plot_dir, wav_dir, mel_output_dir, meta_folder


  def _select_device(hparams):
      """Return the CUDA device when available, otherwise the CPU.

      On CUDA, every scheduled batch size is validated against the number of
      visible GPUs.

      Raises:
          ValueError: if a scheduled batch size is not divisible by the GPU count.
      """
      # From WaveRNN/train_tacotron.py
      if not torch.cuda.is_available():
          return torch.device("cpu")
      n_gpus = torch.cuda.device_count()
      for _, _, _, batch_size in hparams.tts_schedule:
          if batch_size % n_gpus != 0:
              raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
      return torch.device("cuda")


  def _write_symbol_metadata(meta_folder: Path):
      """Write every text symbol, one per line, to ``CharacterEmbeddings.tsv``."""
      char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
      with open(char_embedding_fpath, "w", encoding="utf-8") as f:
          for symbol in symbols:
              if symbol == " ":
                  symbol = "\\s"  # For visual purposes, swap space with \s
              f.write("{}\n".format(symbol))


  def train(run_id: str, syn_dir: Path, models_dir: Path, save_every: int, backup_every: int, force_restart: bool,
            hparams):
      """Train (or resume training) a Tacotron synthesizer along ``hparams.tts_schedule``.

      Each schedule entry is unpacked as ``(r, lr, max_step, batch_size)``.
      Sessions whose ``max_step`` the model has already reached are skipped.
      The main checkpoint is ``<models_dir>/<run_id>/synthesizer.pt``, and numbered
      backups are written next to it. Evaluation plots, wavs and mels go into
      sub-folders of the run directory.

      Args:
          run_id: Name of the run; used as the sub-directory of ``models_dir``.
          syn_dir: Preprocessed data directory holding ``train.txt``, ``mels/``
              and ``embeds/``.
          models_dir: Root directory for all runs.
          save_every: Overwrite the main checkpoint every N steps (0 disables).
          backup_every: Write a numbered backup checkpoint every N steps (0 disables).
          force_restart: Ignore any existing checkpoint and start from scratch.
          hparams: Hyperparameter object (``tts_*`` fields, ``num_mels``, ...).

      Raises:
          ValueError: if a scheduled batch size is not divisible by the GPU count,
              or if training is required but the dataset is empty.
      """
      model_dir, plot_dir, wav_dir, mel_output_dir, meta_folder = _prepare_run_dirs(models_dir, run_id)
      weights_fpath = model_dir / "synthesizer.pt"
      metadata_fpath = syn_dir.joinpath("train.txt")

      print("Checkpoint path: {}".format(weights_fpath))
      print("Loading training data from: {}".format(metadata_fpath))
      print("Using model: Tacotron")

      # Bookkeeping: moving averages over the last 100 steps
      time_window = ValueWindow(100)
      loss_window = ValueWindow(100)

      device = _select_device(hparams)
      print("Using device:", device)

      # Instantiate Tacotron Model
      print("\nInitialising Tacotron Model...\n")
      model = Tacotron(embed_dims=hparams.tts_embed_dims,
                       num_chars=len(symbols),
                       encoder_dims=hparams.tts_encoder_dims,
                       decoder_dims=hparams.tts_decoder_dims,
                       n_mels=hparams.num_mels,
                       fft_bins=hparams.num_mels,
                       postnet_dims=hparams.tts_postnet_dims,
                       encoder_K=hparams.tts_encoder_K,
                       lstm_dims=hparams.tts_lstm_dims,
                       postnet_K=hparams.tts_postnet_K,
                       num_highways=hparams.tts_num_highways,
                       dropout=hparams.tts_dropout,
                       stop_threshold=hparams.tts_stop_threshold,
                       speaker_embedding_size=hparams.speaker_embedding_size).to(device)

      optimizer = optim.Adam(model.parameters())

      # Load the weights (or initialise a fresh checkpoint)
      if force_restart or not weights_fpath.exists():
          print("\nStarting the training of Tacotron from scratch\n")
          model.save(weights_fpath)
          _write_symbol_metadata(meta_folder)
      else:
          print("\nLoading weights at %s" % weights_fpath)
          model.load(weights_fpath, optimizer)
          print("Tacotron weights loaded from step %d" % model.step)

      # Initialize the dataset
      mel_dir = syn_dir.joinpath("mels")
      embed_dir = syn_dir.joinpath("embeds")
      dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)

      for session in hparams.tts_schedule:
          current_step = model.get_step()
          r, lr, max_step, batch_size = session
          training_steps = max_step - current_step

          # This session's step budget is already spent: move on to the next one.
          # The final checkpoint is written once, after the loop.
          if current_step >= max_step:
              continue

          total_iters = len(dataset)
          if total_iters == 0:
              # Otherwise the steps-per-epoch maths below divides by zero.
              raise ValueError("No training samples found in {}".format(metadata_fpath))

          model.r = r

          # Begin the training
          simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
                        ("Batch Size", batch_size),
                        ("Learning Rate", lr),
                        ("Outputs/Step (r)", model.r)])

          for p in optimizer.param_groups:
              p["lr"] = lr

          collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
          data_loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)

          # Exact integer ceil-division (no float / int32 round-trip).
          steps_per_epoch = -(-total_iters // batch_size)
          epochs = -(-training_steps // steps_per_epoch)

          step = current_step
          for epoch in range(1, epochs + 1):
              for batch_idx, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
                  start_time = time.time()

                  # Stop-token targets: 0 for the first (mel length - 1) frames of each
                  # item, 1 for the remaining (including padded) frames.
                  stop = torch.ones(mels.shape[0], mels.shape[2])
                  for j, sample_id in enumerate(idx):
                      stop[j, :int(dataset.metadata[sample_id][4]) - 1] = 0

                  texts = texts.to(device)
                  mels = mels.to(device)
                  embeds = embeds.to(device)
                  stop = stop.to(device)

                  # Forward pass
                  # Parallelize model onto GPUS using workaround due to python bug
                  if device.type == "cuda" and torch.cuda.device_count() > 1:
                      m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts, mels, embeds)
                  else:
                      m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)

                  # Backward pass
                  m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
                  m2_loss = F.mse_loss(m2_hat, mels)
                  stop_loss = F.binary_cross_entropy(stop_pred, stop)
                  loss = m1_loss + m2_loss + stop_loss

                  optimizer.zero_grad()
                  loss.backward()

                  if hparams.tts_clip_grad_norm is not None:
                      grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
                      if np.isnan(grad_norm.cpu()):
                          print("grad_norm was NaN!")

                  optimizer.step()

                  # Keep only a Python float so the autograd graph is not retained.
                  loss_value = loss.item()
                  time_window.append(time.time() - start_time)
                  loss_window.append(loss_value)

                  step = model.get_step()
                  kilo_step = step // 1000

                  msg = f"| Epoch: {epoch}/{epochs} ({batch_idx}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | " \
                        f"{1./time_window.average:#.2} steps/s | Step: {kilo_step}k | "
                  stream(msg)

                  # Backup or save model as appropriate
                  if backup_every != 0 and step % backup_every == 0:
                      backup_fpath = weights_fpath.parent / f"synthesizer_{kilo_step:06d}.pt"
                      model.save(backup_fpath, optimizer)

                  if save_every != 0 and step % save_every == 0:
                      # Must save latest optimizer state to ensure that resuming training
                      # doesn't produce artifacts
                      model.save(weights_fpath, optimizer)

                  # Evaluate model to generate samples
                  epoch_eval = hparams.tts_eval_interval == -1 and batch_idx == steps_per_epoch  # If epoch is done
                  step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0  # Every N steps
                  if epoch_eval or step_eval:
                      # At most, generate samples equal to number in the batch
                      for sample_idx in range(min(hparams.tts_eval_num_samples, len(texts))):
                          # Remove padding from mels using frame length in metadata
                          mel_length = int(dataset.metadata[idx[sample_idx]][4])
                          mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
                          target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
                          attention_len = mel_length // model.r

                          eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
                                     mel_prediction=mel_prediction,
                                     target_spectrogram=target_spectrogram,
                                     input_seq=np_now(texts[sample_idx]),
                                     step=step,
                                     plot_dir=plot_dir,
                                     mel_output_dir=mel_output_dir,
                                     wav_dir=wav_dir,
                                     sample_num=sample_idx + 1,
                                     loss=loss_value,
                                     hparams=hparams)

                  # Break out of loop to update training schedule
                  if step >= max_step:
                      break

              # Add line break after every epoch
              print("")
              if step >= max_step:
                  break

      # Always persist the final state. The periodic save above only fires on
      # multiples of save_every, so otherwise the tail of the last session would
      # be lost until the script was run again.
      model.save(weights_fpath, optimizer)
  def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
                 plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
      """Write the diagnostic artifacts for one evaluation sample.

      In order, it saves:

      - an attention plot into ``plot_dir``;
      - the predicted mel as ``.npy`` into ``mel_output_dir``;
      - a wav reconstructed from the predicted mel into ``wav_dir``;
      - a predicted-vs-target spectrogram plot into ``plot_dir``.

      It then prints the decoded input text. ``loss`` is only used for the
      plot title (formatted with ``.5f``).
      """
      # Attention alignment plot
      attention_path = plot_dir / f"attention_step_{step}_sample_{sample_num}"
      save_attention(attention, str(attention_path))

      # Raw predicted mel spectrogram (debug)
      mel_npy_path = mel_output_dir / f"mel-prediction-step-{step}_sample_{sample_num}.npy"
      np.save(str(mel_npy_path), mel_prediction, allow_pickle=False)

      # Griffin-Lim inverted waveform (mel -> wav) for listening checks
      wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
      wav_path = wav_dir / f"step-{step}-wave-from-mel_sample_{sample_num}.wav"
      audio.save_wav(wav, str(wav_path), sr=hparams.sample_rate)

      # Predicted vs. target mel plot (control purposes)
      spec_path = plot_dir / f"step-{step}-mel-spectrogram_sample_{sample_num}.png"
      title = f"Tacotron, {time_string()}, step={step}, loss={loss:.5f}"
      plot_spectrogram(mel_prediction, str(spec_path), title=title,
                       target_spectrogram=target_spectrogram,
                       max_len=target_spectrogram.size // hparams.num_mels)

      print(f"Input at step {step}: {sequence_to_text(input_seq)}")