synthesizer_dataset.py

import torch
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from synthesizer.utils.text import text_to_sequence


class SynthesizerDataset(Dataset):
    def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams):
        print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir))

        # Metadata is pipe-delimited; only entries with a nonzero value in column 4 are kept
        with metadata_fpath.open("r") as metadata_file:
            metadata = [line.split("|") for line in metadata_file]

        mel_fnames = [x[1] for x in metadata if int(x[4])]
        mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames]
        embed_fnames = [x[2] for x in metadata if int(x[4])]
        embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames]
        self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths))
        self.samples_texts = [x[5].strip() for x in metadata if int(x[4])]
        self.metadata = metadata
        self.hparams = hparams

        print("Found %d samples" % len(self.samples_fpaths))

    def __getitem__(self, index):
        # Sometimes index may be a list of 2 (not sure why this happens)
        # If that is the case, return a single item corresponding to the first element of index
        if isinstance(index, list):
            index = index[0]

        mel_path, embed_path = self.samples_fpaths[index]
        mel = np.load(mel_path).T.astype(np.float32)

        # Load the embed
        embed = np.load(embed_path)

        # Get the text and clean it
        text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names)

        # Convert the list returned by text_to_sequence to a numpy array
        text = np.asarray(text).astype(np.int32)

        return text, mel.astype(np.float32), embed.astype(np.float32), index

    def __len__(self):
        return len(self.samples_fpaths)
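
# Expected metadata layout (pipe-delimited), as implied by the indexing in __init__:
#   <unused>|<mel fname>|<embed fname>|<unused>|<nonzero keep flag>|<text>
# In the SV2TTS preprocessing scripts the flag in column 4 is the utterance's mel
# frame count, but any nonzero integer keeps the sample. Example line (hypothetical,
# not taken from an actual metadata file):
#   audio-000.npy|mel-000.npy|embed-000.npy|52480|164|this is an example utterance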

def collate_synthesizer(batch, r, hparams):
    # Text
    x_lens = [len(x[0]) for x in batch]
    max_x_len = max(x_lens)

    chars = [pad1d(x[0], max_x_len) for x in batch]
    chars = np.stack(chars)

    # Mel spectrogram
    spec_lens = [x[1].shape[-1] for x in batch]
    max_spec_len = max(spec_lens) + 1
    # Round the padded length up to a multiple of r (the model's reduction factor)
    if max_spec_len % r != 0:
        max_spec_len += r - max_spec_len % r

    # WaveRNN mel spectrograms are normalized to [0, 1], so zero padding adds silence
    # By default, SV2TTS uses symmetric mels, where -1 * max_abs_value is silence.
    if hparams.symmetric_mels:
        mel_pad_value = -1 * hparams.max_abs_value
    else:
        mel_pad_value = 0

    mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch]
    mel = np.stack(mel)

    # Speaker embedding (SV2TTS)
    embeds = np.array([x[2] for x in batch])

    # Index (for vocoder preprocessing)
    indices = [x[3] for x in batch]

    # Convert all to tensors
    chars = torch.tensor(chars).long()
    mel = torch.tensor(mel)
    embeds = torch.tensor(embeds)

    return chars, mel, embeds, indices

def pad1d(x, max_len, pad_value=0):
    return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)


def pad2d(x, max_len, pad_value=0):
    return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value)
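
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): wiring the dataset and
# collate function into a DataLoader. It assumes the preprocessed data lives
# under synthesizer/ and that `hparams` can be imported from synthesizer.hparams
# as elsewhere in the SV2TTS codebase; the batch size and r (the reduction
# factor, i.e. mel frames generated per decoder step) are placeholder values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from functools import partial
    from torch.utils.data import DataLoader
    from synthesizer.hparams import hparams

    dataset = SynthesizerDataset(metadata_fpath=Path("synthesizer/train.txt"),
                                 mel_dir=Path("synthesizer/mels"),
                                 embed_dir=Path("synthesizer/embeds"),
                                 hparams=hparams)
    loader = DataLoader(dataset,
                        batch_size=16,
                        shuffle=True,
                        collate_fn=partial(collate_synthesizer, r=2, hparams=hparams))

    # Each batch: chars (B, T_text) int64, mel (B, n_mels, T_mel) float32,
    # embeds (B, embed_dim) float32, indices is a list of dataset indices
    chars, mel, embeds, indices = next(iter(loader))
    print(chars.shape, mel.shape, embeds.shape, len(indices))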