dataset.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import os
  2. import random
  3. import numpy as np
  4. import torch
  5. import torch.utils.data
  6. from tqdm import tqdm
  7. from . import spec_utils
  8. class VocalRemoverValidationSet(torch.utils.data.Dataset):
  9. def __init__(self, patch_list):
  10. self.patch_list = patch_list
  11. def __len__(self):
  12. return len(self.patch_list)
  13. def __getitem__(self, idx):
  14. path = self.patch_list[idx]
  15. data = np.load(path)
  16. X, y = data["X"], data["y"]
  17. X_mag = np.abs(X)
  18. y_mag = np.abs(y)
  19. return X_mag, y_mag
  20. def make_pair(mix_dir, inst_dir):
  21. input_exts = [".wav", ".m4a", ".mp3", ".mp4", ".flac"]
  22. X_list = sorted(
  23. [
  24. os.path.join(mix_dir, fname)
  25. for fname in os.listdir(mix_dir)
  26. if os.path.splitext(fname)[1] in input_exts
  27. ]
  28. )
  29. y_list = sorted(
  30. [
  31. os.path.join(inst_dir, fname)
  32. for fname in os.listdir(inst_dir)
  33. if os.path.splitext(fname)[1] in input_exts
  34. ]
  35. )
  36. filelist = list(zip(X_list, y_list))
  37. return filelist
  38. def train_val_split(dataset_dir, split_mode, val_rate, val_filelist):
  39. if split_mode == "random":
  40. filelist = make_pair(
  41. os.path.join(dataset_dir, "mixtures"),
  42. os.path.join(dataset_dir, "instruments"),
  43. )
  44. random.shuffle(filelist)
  45. if len(val_filelist) == 0:
  46. val_size = int(len(filelist) * val_rate)
  47. train_filelist = filelist[:-val_size]
  48. val_filelist = filelist[-val_size:]
  49. else:
  50. train_filelist = [
  51. pair for pair in filelist if list(pair) not in val_filelist
  52. ]
  53. elif split_mode == "subdirs":
  54. if len(val_filelist) != 0:
  55. raise ValueError(
  56. "The `val_filelist` option is not available in `subdirs` mode"
  57. )
  58. train_filelist = make_pair(
  59. os.path.join(dataset_dir, "training/mixtures"),
  60. os.path.join(dataset_dir, "training/instruments"),
  61. )
  62. val_filelist = make_pair(
  63. os.path.join(dataset_dir, "validation/mixtures"),
  64. os.path.join(dataset_dir, "validation/instruments"),
  65. )
  66. return train_filelist, val_filelist
  67. def augment(X, y, reduction_rate, reduction_mask, mixup_rate, mixup_alpha):
  68. perm = np.random.permutation(len(X))
  69. for i, idx in enumerate(tqdm(perm)):
  70. if np.random.uniform() < reduction_rate:
  71. y[idx] = spec_utils.reduce_vocal_aggressively(
  72. X[idx], y[idx], reduction_mask
  73. )
  74. if np.random.uniform() < 0.5:
  75. # swap channel
  76. X[idx] = X[idx, ::-1]
  77. y[idx] = y[idx, ::-1]
  78. if np.random.uniform() < 0.02:
  79. # mono
  80. X[idx] = X[idx].mean(axis=0, keepdims=True)
  81. y[idx] = y[idx].mean(axis=0, keepdims=True)
  82. if np.random.uniform() < 0.02:
  83. # inst
  84. X[idx] = y[idx]
  85. if np.random.uniform() < mixup_rate and i < len(perm) - 1:
  86. lam = np.random.beta(mixup_alpha, mixup_alpha)
  87. X[idx] = lam * X[idx] + (1 - lam) * X[perm[i + 1]]
  88. y[idx] = lam * y[idx] + (1 - lam) * y[perm[i + 1]]
  89. return X, y
  90. def make_padding(width, cropsize, offset):
  91. left = offset
  92. roi_size = cropsize - left * 2
  93. if roi_size == 0:
  94. roi_size = cropsize
  95. right = roi_size - (width % roi_size) + left
  96. return left, right, roi_size
  97. def make_training_set(filelist, cropsize, patches, sr, hop_length, n_fft, offset):
  98. len_dataset = patches * len(filelist)
  99. X_dataset = np.zeros((len_dataset, 2, n_fft // 2 + 1, cropsize), dtype=np.complex64)
  100. y_dataset = np.zeros((len_dataset, 2, n_fft // 2 + 1, cropsize), dtype=np.complex64)
  101. for i, (X_path, y_path) in enumerate(tqdm(filelist)):
  102. X, y = spec_utils.cache_or_load(X_path, y_path, sr, hop_length, n_fft)
  103. coef = np.max([np.abs(X).max(), np.abs(y).max()])
  104. X, y = X / coef, y / coef
  105. l, r, roi_size = make_padding(X.shape[2], cropsize, offset)
  106. X_pad = np.pad(X, ((0, 0), (0, 0), (l, r)), mode="constant")
  107. y_pad = np.pad(y, ((0, 0), (0, 0), (l, r)), mode="constant")
  108. starts = np.random.randint(0, X_pad.shape[2] - cropsize, patches)
  109. ends = starts + cropsize
  110. for j in range(patches):
  111. idx = i * patches + j
  112. X_dataset[idx] = X_pad[:, :, starts[j] : ends[j]]
  113. y_dataset[idx] = y_pad[:, :, starts[j] : ends[j]]
  114. return X_dataset, y_dataset
  115. def make_validation_set(filelist, cropsize, sr, hop_length, n_fft, offset):
  116. patch_list = []
  117. patch_dir = "cs{}_sr{}_hl{}_nf{}_of{}".format(
  118. cropsize, sr, hop_length, n_fft, offset
  119. )
  120. os.makedirs(patch_dir, exist_ok=True)
  121. for i, (X_path, y_path) in enumerate(tqdm(filelist)):
  122. basename = os.path.splitext(os.path.basename(X_path))[0]
  123. X, y = spec_utils.cache_or_load(X_path, y_path, sr, hop_length, n_fft)
  124. coef = np.max([np.abs(X).max(), np.abs(y).max()])
  125. X, y = X / coef, y / coef
  126. l, r, roi_size = make_padding(X.shape[2], cropsize, offset)
  127. X_pad = np.pad(X, ((0, 0), (0, 0), (l, r)), mode="constant")
  128. y_pad = np.pad(y, ((0, 0), (0, 0), (l, r)), mode="constant")
  129. len_dataset = int(np.ceil(X.shape[2] / roi_size))
  130. for j in range(len_dataset):
  131. outpath = os.path.join(patch_dir, "{}_p{}.npz".format(basename, j))
  132. start = j * roi_size
  133. if not os.path.exists(outpath):
  134. np.savez(
  135. outpath,
  136. X=X_pad[:, :, start : start + cropsize],
  137. y=y_pad[:, :, start : start + cropsize],
  138. )
  139. patch_list.append(outpath)
  140. return VocalRemoverValidationSet(patch_list)