train.py

from pathlib import Path

import torch

from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
from encoder.model import SpeakerEncoder
from encoder.params_model import *
from encoder.visualizations import Visualizations
from utils.profiler import Profiler
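
# Note: the wildcard import from encoder.params_model supplies the hyperparameters used
# below (speakers_per_batch, utterances_per_speaker, learning_rate_init).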

def sync(device: torch.device):
    # For correct profiling (cuda operations are async)
    if device.type == "cuda":
        torch.cuda.synchronize(device)

def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
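    """
    Trains the speaker encoder and periodically saves checkpoints to models_dir/run_id.
    Passing 0 for umap_every, save_every or backup_every disables that periodic action.
    """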
    # Create a dataset and a dataloader
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=4,
    )
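    # Each batch holds utterances_per_speaker utterances for each of speakers_per_batch
    # speakers; the loop below has no stopping condition, so training runs until interrupted.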

    # Set up the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure file path for the model
    model_dir = models_dir / run_id
    model_dir.mkdir(exist_ok=True, parents=True)
    state_fpath = model_dir / "encoder.pt"

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            checkpoint = torch.load(state_fpath)
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
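            # Override the checkpointed learning rate with the configured initial value, so
            # that changes to params_model take effect when resuming.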
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=10, disabled=False)
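    # enumerate(loader, init_step) resumes the step counter from the loaded checkpoint.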
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
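        # Reshape the flat batch of embeddings to (speakers_per_batch, utterances_per_speaker,
        # embedding_size) and move it to the loss device before computing the loss and EER.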
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
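        # Let the model post-process its own gradients (e.g. scaling/clipping) before the
        # optimizer step.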
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        # learning_rate = optimizer.param_groups[0]["lr"]
        vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            projection_fpath = model_dir / f"umap_{step:06d}.png"
            embeds = embeds.detach().cpu().numpy()
            vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
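            # Save step + 1 so that a resumed run continues from the next step.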
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, state_fpath)

        # Make a backup
        if backup_every != 0 and step % backup_every == 0:
            print("Making a backup (step %d)" % step)
            backup_fpath = model_dir / f"encoder_{step:06d}.bak"
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, backup_fpath)

        profiler.tick("Extras (visualizations, saving)")
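

# Example invocation (illustrative: the argument values below are placeholders, and in
# practice this function would typically be driven by a separate CLI entry-point script):
#
#   train(run_id="my_run",
#         clean_data_root=Path("datasets/SV2TTS/encoder"),
#         models_dir=Path("saved_models"),
#         umap_every=100, save_every=500, backup_every=7500, vis_every=10,
#         force_restart=False, visdom_server="http://localhost", no_visdom=False)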