train.py

import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader

import vocoder.hparams as hp
from vocoder.display import stream, simple_table
from vocoder.distribution import discretized_mix_logistic_loss
from vocoder.gen_wavernn import gen_testset
from vocoder.models.fatchord_version import WaveRNN
from vocoder.vocoder_dataset import VocoderDataset, collate_vocoder


def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool,
          save_every: int, backup_every: int, force_restart: bool):
    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length
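    # (the conditioning network stretches each mel frame by these factors in
    # turn, so their product must equal the number of samples per frame)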

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode
    )
    if torch.cuda.is_available():
        model = model.cuda()

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr
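    # RAW mode treats each output sample as one of 2 ** bits quantization
    # classes, hence cross-entropy; MOL mode fits a discretized mixture of
    # logistics to the waveform instead.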
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Load the weights
    model_dir = models_dir / run_id
    model_dir.mkdir(exist_ok=True)
    weights_fpath = model_dir / "vocoder.pt"
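    # Either start from scratch (writing an initial checkpoint) or resume from
    # the existing weights file.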
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(weights_fpath, optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)

    # Initialize the dataset
    metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \
        voc_dir.joinpath("synthesized.txt")
    mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath("mels_gta")
    wav_dir = syn_dir.joinpath("audio")
    dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)
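    # A single-utterance loader for gen_testset, which renders evaluation
    # audio at the end of every epoch.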
    test_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size),
                  ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    for epoch in range(1, 350):
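        # collate_vocoder assembles each batch into fixed-length mel/waveform
        # training sequences (hp.voc_seq_len samples each).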
        data_loader = DataLoader(dataset, hp.voc_batch_size, shuffle=True, num_workers=2,
                                 collate_fn=collate_vocoder)
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(data_loader, 1):
            if torch.cuda.is_available():
                x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass
            y_hat = model(x, m)
            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)
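            # The transpose/unsqueeze above arrange y_hat as (batch, classes,
            # time, 1) and y as (batch, time, 1), so F.cross_entropy's
            # K-dimensional form applies in RAW mode; in MOL mode the float
            # targets go straight into the mixture-of-logistics loss.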

            # Backward pass
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
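            # Throughput and mean loss over the epoch so far, for the progress
            # line below.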
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000
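
            # step persists across restarts (it is restored with the weights);
            # checkpoint() writes a separate backup in model_dir, while save()
            # overwrites the main vocoder.pt.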
            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(model_dir, optimizer)

            if save_every != 0 and step % save_every == 0:
                model.save(weights_fpath, optimizer)

            msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
                  f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                  f"steps/s | Step: {k}k | "
            stream(msg)
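
        # At the end of each epoch, synthesize a few utterances from the test
        # loader so output quality can be checked by ear.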
        gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                    hp.voc_target, hp.voc_overlap, model_dir)
        print("")
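

# A minimal sketch of how train() might be invoked; the paths and schedule
# values here are hypothetical (upstream, a CLI wrapper supplies them):
#
#   train(run_id="my_run",
#         syn_dir=Path("datasets/SV2TTS/synthesizer"),
#         voc_dir=Path("datasets/SV2TTS/vocoder"),
#         models_dir=Path("saved_models"),
#         ground_truth=True,
#         save_every=1000,
#         backup_every=25000,
#         force_restart=False)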