speaker_verification_dataset.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. from encoder.data_objects.random_cycler import RandomCycler
  2. from encoder.data_objects.speaker_batch import SpeakerBatch
  3. from encoder.data_objects.speaker import Speaker
  4. from encoder.params_data import partials_n_frames
  5. from torch.utils.data import Dataset, DataLoader
  6. from pathlib import Path
  7. # TODO: improve with a pool of speakers for data efficiency
  8. class SpeakerVerificationDataset(Dataset):
  9. def __init__(self, datasets_root: Path):
  10. self.root = datasets_root
  11. speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
  12. if len(speaker_dirs) == 0:
  13. raise Exception("No speakers found. Make sure you are pointing to the directory "
  14. "containing all preprocessed speaker directories.")
  15. self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
  16. self.speaker_cycler = RandomCycler(self.speakers)
  17. def __len__(self):
  18. return int(1e10)
  19. def __getitem__(self, index):
  20. return next(self.speaker_cycler)
  21. def get_logs(self):
  22. log_string = ""
  23. for log_fpath in self.root.glob("*.txt"):
  24. with log_fpath.open("r") as log_file:
  25. log_string += "".join(log_file.readlines())
  26. return log_string
  27. class SpeakerVerificationDataLoader(DataLoader):
  28. def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
  29. batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
  30. worker_init_fn=None):
  31. self.utterances_per_speaker = utterances_per_speaker
  32. super().__init__(
  33. dataset=dataset,
  34. batch_size=speakers_per_batch,
  35. shuffle=False,
  36. sampler=sampler,
  37. batch_sampler=batch_sampler,
  38. num_workers=num_workers,
  39. collate_fn=self.collate,
  40. pin_memory=pin_memory,
  41. drop_last=False,
  42. timeout=timeout,
  43. worker_init_fn=worker_init_fn
  44. )
  45. def collate(self, speakers):
  46. return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)