vocoder_dataset.py

from torch.utils.data import Dataset
from pathlib import Path
from vocoder import audio
import vocoder.hparams as hp
import numpy as np
import torch


class VocoderDataset(Dataset):
    def __init__(self, metadata_fpath: Path, mel_dir: Path, wav_dir: Path):
        print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, wav_dir))

        with metadata_fpath.open("r") as metadata_file:
            metadata = [line.split("|") for line in metadata_file]

        gta_fnames = [x[1] for x in metadata if int(x[4])]
        gta_fpaths = [mel_dir.joinpath(fname) for fname in gta_fnames]
        wav_fnames = [x[0] for x in metadata if int(x[4])]
        wav_fpaths = [wav_dir.joinpath(fname) for fname in wav_fnames]
        self.samples_fpaths = list(zip(gta_fpaths, wav_fpaths))

        print("Found %d samples" % len(self.samples_fpaths))

    def __getitem__(self, index):
        mel_path, wav_path = self.samples_fpaths[index]

        # Load the mel spectrogram and adjust its range to [-1, 1]
        mel = np.load(mel_path).T.astype(np.float32) / hp.mel_max_abs_value

        # Load the wav
        wav = np.load(wav_path)
        if hp.apply_preemphasis:
            wav = audio.pre_emphasis(wav)
        wav = np.clip(wav, -1, 1)

        # Fix for missing padding  # TODO: settle on whether this is any useful
        r_pad = (len(wav) // hp.hop_length + 1) * hp.hop_length - len(wav)
        wav = np.pad(wav, (0, r_pad), mode='constant')
        assert len(wav) >= mel.shape[1] * hp.hop_length
        wav = wav[:mel.shape[1] * hp.hop_length]
        assert len(wav) % hp.hop_length == 0

        # Quantize the wav
        if hp.voc_mode == 'RAW':
            if hp.mu_law:
                quant = audio.encode_mu_law(wav, mu=2 ** hp.bits)
            else:
                quant = audio.float_2_label(wav, bits=hp.bits)
        elif hp.voc_mode == 'MOL':
            quant = audio.float_2_label(wav, bits=16)

        return mel.astype(np.float32), quant.astype(np.int64)

    def __len__(self):
        return len(self.samples_fpaths)
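

# Illustrative sketch only (not part of the original module): the quantization
# used in __getitem__ lives in vocoder/audio.py (encode_mu_law, float_2_label).
# This standalone helper shows the standard mu-law companding those calls are
# assumed to implement; the real audio module may differ in details such as
# rounding or whether it compands with mu or mu - 1.
def _mu_law_encode_sketch(wav: np.ndarray, bits: int) -> np.ndarray:
    mu = 2 ** bits - 1
    # Compress: sign(x) * ln(1 + mu*|x|) / ln(1 + mu), still in [-1, 1]
    compressed = np.sign(wav) * np.log1p(mu * np.abs(wav)) / np.log1p(mu)
    # Map [-1, 1] onto integer labels in [0, 2**bits - 1]
    return np.floor((compressed + 1) / 2 * mu + 0.5).astype(np.int64)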


def collate_vocoder(batch):
    # Pick a random mel window (with voc_pad frames of context on each side)
    # per sample, plus the matching slice of the quantized waveform
    mel_win = hp.voc_seq_len // hp.hop_length + 2 * hp.voc_pad
    max_offsets = [x[0].shape[-1] - 2 - (mel_win + 2 * hp.voc_pad) for x in batch]
    mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
    sig_offsets = [(offset + hp.voc_pad) * hp.hop_length for offset in mel_offsets]

    mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] for i, x in enumerate(batch)]
    labels = [x[1][sig_offsets[i]:sig_offsets[i] + hp.voc_seq_len + 1] for i, x in enumerate(batch)]

    mels = np.stack(mels).astype(np.float32)
    labels = np.stack(labels).astype(np.int64)

    mels = torch.tensor(mels)
    labels = torch.tensor(labels).long()

    # The target sequence is the input shifted by one sample
    x = labels[:, :hp.voc_seq_len]
    y = labels[:, 1:]

    bits = 16 if hp.voc_mode == 'MOL' else hp.bits

    x = audio.label_2_float(x.float(), bits)
    if hp.voc_mode == 'MOL':
        y = audio.label_2_float(y.float(), bits)

    return x, y, mels
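

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): wire the dataset
    # into a DataLoader with the custom collate function. The paths below are
    # placeholders for wherever the synthesizer preprocessing wrote its
    # metadata file, GTA mel directory and wav directory.
    from torch.utils.data import DataLoader

    syn_dir = Path("path/to/synthesizer_output")  # placeholder
    dataset = VocoderDataset(metadata_fpath=syn_dir.joinpath("synthesized.txt"),
                             mel_dir=syn_dir.joinpath("mels_gta"),
                             wav_dir=syn_dir.joinpath("audio"))
    loader = DataLoader(dataset, batch_size=32, shuffle=True,
                        collate_fn=collate_vocoder, num_workers=2)

    # Each batch is (x, y, mels): network input, one-sample-shifted target and
    # the conditioning mel windows, all as torch tensors
    x, y, mels = next(iter(loader))
    print(x.shape, y.shape, mels.shape)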