# data_utils.py

import time
import logging
import os
import random
import traceback
import numpy as np
import torch
import torch.utils.data
from tqdm import tqdm
from module import commons
from module.mel_processing import spectrogram_torch
from text import cleaned_text_to_sequence
from utils import load_wav_to_torch, load_filepaths_and_text
import torch.nn.functional as F
from functools import lru_cache
import requests
from scipy.io import wavfile
from io import BytesIO
from my_utils import load_audio


# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
    """
    1) loads audio, speaker_id, text pairs
    2) normalizes text and converts them to sequences of integers
    3) computes spectrograms from audio files.
    """

    def __init__(self, hparams, val=False):
        exp_dir = hparams.exp_dir
        self.path2 = "%s/2-name2text.txt" % exp_dir
        self.path4 = "%s/4-cnhubert" % exp_dir
        self.path5 = "%s/5-wav32k" % exp_dir
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path4)
        assert os.path.exists(self.path5)

        names4 = set([name[:-3] for name in list(os.listdir(self.path4))])  # strip the ".pt" suffix
        names5 = set(os.listdir(self.path5))
        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")

        for line in lines:
            tmp = line.split("\t")
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1]]

        # Keep only names present in the text file, the cnhubert dir and the wav dir.
        self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)

        tmp = self.audiopaths_sid_text
        leng = len(tmp)
        min_num = 100
        if leng < min_num:
            # Oversample small datasets until at least min_num items are available.
            self.audiopaths_sid_text = []
            for _ in range(max(2, int(min_num / leng))):
                self.audiopaths_sid_text += tmp

        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.val = val

        random.seed(1234)
        random.shuffle(self.audiopaths_sid_text)

        print("phoneme_data_len:", len(self.phoneme_data.keys()))
        print("wav_data_len:", len(self.audiopaths_sid_text))

        audiopaths_sid_text_new = []
        lengths = []
        skipped_phone = 0
        skipped_dur = 0
        for audiopath in tqdm(self.audiopaths_sid_text):
            try:
                phoneme = self.phoneme_data[audiopath][0]
                phoneme = phoneme.split(' ')
                phoneme_ids = cleaned_text_to_sequence(phoneme)
            except Exception:
                print(f"{audiopath} not in self.phoneme_data !")
                skipped_phone += 1
                continue

            size = os.path.getsize("%s/%s" % (self.path5, audiopath))
            duration = size / self.sampling_rate / 2  # 16-bit PCM: two bytes per sample
            if duration == 0:
                print(f"Zero duration for {audiopath}, skipping...")
                skipped_dur += 1
                continue

            # Keep clips between 0.6 s and 54 s; in validation mode keep everything.
            if 54 > duration > 0.6 or self.val:
                audiopaths_sid_text_new.append([audiopath, phoneme_ids])
                lengths.append(size // (2 * self.hop_length))
            else:
                skipped_dur += 1
                continue

        print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
        print("total left: ", len(audiopaths_sid_text_new))
        assert len(audiopaths_sid_text_new) > 1  # needs at least enough items for a batch; TODO
        self.audiopaths_sid_text = audiopaths_sid_text_new
        self.lengths = lengths

    def get_audio_text_speaker_pair(self, audiopath_sid_text):
        audiopath, phoneme_ids = audiopath_sid_text
        text = torch.FloatTensor(phoneme_ids)
        try:
            spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
            with torch.no_grad():
                ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
                if ssl.shape[-1] != spec.shape[-1]:
                    # Pad the SSL features by one frame so they line up with the spectrogram.
                    ssl_dtype = ssl.dtype
                    ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(ssl_dtype)
                ssl.requires_grad = False
        except Exception:
            traceback.print_exc()
            # Fall back to dummy tensors so one bad file cannot kill training.
            spec = torch.zeros(1025, 100)
            wav = torch.zeros(1, 100 * self.hop_length)
            ssl = torch.zeros(1, 768, 100)
            text = text[-1:]
            print("load audio or ssl error!!!!!!", audiopath)
        return (ssl, spec, wav, text)

    def get_audio(self, filename):
        # load_audio already normalizes to [-1, 1], so no further /32768 is needed.
        audio_array = load_audio(filename, self.sampling_rate)
        audio = torch.FloatTensor(audio_array)
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate,
                                 self.hop_length, self.win_length, center=False)
        spec = torch.squeeze(spec, 0)
        return spec, audio_norm

    def get_sid(self, sid):
        sid = torch.LongTensor([int(sid)])
        return sid

    def __getitem__(self, index):
        # with torch.no_grad():
        return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])

    def __len__(self):
        return len(self.audiopaths_sid_text)

    def random_slice(self, ssl, wav, mel):
        assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
            "first", ssl.shape, wav.shape)

        len_mel = mel.shape[1]
        if self.val:
            reference_mel = mel[:, :len_mel // 3]
            return reference_mel, ssl, wav, mel

        direction = random.randint(0, 1)  # 0: reference from the front, 1: from the back
        sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))

        if direction == 0:
            reference_mel = mel[:, :sep_point]
            ssl = ssl[:, :, sep_point:]
            wav2 = wav[:, sep_point * self.hop_length:]
            mel = mel[:, sep_point:]
        else:
            reference_mel = mel[:, sep_point:]
            ssl = ssl[:, :, :sep_point]
            wav2 = wav[:, :sep_point * self.hop_length]
            mel = mel[:, :sep_point]

        assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
            ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length,
            sep_point * self.hop_length, direction)
        return reference_mel, ssl, wav2, mel
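

# A minimal usage sketch for the loader, assuming a VITS-style config loaded
# with utils.get_hparams_from_file and a `data` section carrying the fields
# __init__ reads (exp_dir, max_wav_value, sampling_rate, filter_length,
# hop_length, win_length). The config path and this helper's name are
# hypothetical, not part of the training pipeline.
def _example_load_dataset():
    import utils
    hps = utils.get_hparams_from_file("configs/s2.json")  # hypothetical path
    dataset = TextAudioSpeakerLoader(hps.data, val=False)
    ssl, spec, wav, text = dataset[0]
    # ssl: (1, 768, T_ssl); spec: (filter_length // 2 + 1, T_spec); wav: (1, T_wav)
    print(ssl.shape, spec.shape, wav.shape, text.shape)
    return dataset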


class TextAudioSpeakerCollate():
    """Zero-pads model inputs and targets."""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collates a training batch from normalized text, audio and speaker identities.

        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized, sid]
        """
        # Right zero-pad all one-hot text sequences to max input length.
        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x[1].size(1) for x in batch]),
            dim=0, descending=True)

        # Pad SSL and spec lengths up to the next even value.
        max_ssl_len = max([x[0].size(2) for x in batch])
        max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
        max_spec_len = max([x[1].size(1) for x in batch])
        max_spec_len = int(2 * ((max_spec_len // 2) + 1))
        max_wav_len = max([x[2].size(1) for x in batch])
        max_text_len = max([x[3].size(0) for x in batch])

        ssl_lengths = torch.LongTensor(len(batch))
        spec_lengths = torch.LongTensor(len(batch))
        wav_lengths = torch.LongTensor(len(batch))
        text_lengths = torch.LongTensor(len(batch))

        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
        wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
        ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
        text_padded = torch.LongTensor(len(batch), max_text_len)

        spec_padded.zero_()
        wav_padded.zero_()
        ssl_padded.zero_()
        text_padded.zero_()

        for i in range(len(ids_sorted_decreasing)):
            row = batch[ids_sorted_decreasing[i]]

            ssl = row[0]
            ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
            ssl_lengths[i] = ssl.size(2)

            spec = row[1]
            spec_padded[i, :, :spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)

            wav = row[2]
            wav_padded[i, :, :wav.size(1)] = wav
            wav_lengths[i] = wav.size(1)

            text = row[3]
            text_padded[i, :text.size(0)] = text
            text_lengths[i] = text.size(0)

        return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
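

# A minimal sketch of wiring the dataset and collate function into a plain
# DataLoader, assuming `dataset` comes from _example_load_dataset above; the
# helper name and batch size are illustrative only.
def _example_dataloader(dataset):
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=False,
        collate_fn=TextAudioSpeakerCollate(), num_workers=0)
    ssl, ssl_len, spec, spec_len, wav, wav_len, text, text_len = next(iter(loader))
    # Each padded tensor is batch-first; the *_len tensors hold the true lengths.
    return loader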


class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintains similar input lengths within a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> every batch lies entirely in either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 is discarded.
    """

    def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        self.boundaries = boundaries

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        # Drop empty buckets, iterating backwards so indices stay valid.
        i = len(buckets) - 1
        while i >= 0:
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)
            i -= 1

        # Round each bucket up to a multiple of the global batch size.
        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # Pad with repeated indices so every replica gets whole batches.
            rem = num_samples_bucket - len_bucket
            ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]

            # Subsample for this replica.
            ids_bucket = ids_bucket[self.rank::self.num_replicas]

            # Slice the bucket into batches.
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        return self.num_samples // self.batch_size
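

# A minimal sketch of bucketed batching on a single process
# (num_replicas=1, rank=0). The boundaries are hypothetical spectrogram-frame
# thresholds matched against dataset.lengths; in real distributed training,
# num_replicas and rank come from torch.distributed.
def _example_bucket_sampler(dataset):
    sampler = DistributedBucketSampler(
        dataset, batch_size=4,
        boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],
        num_replicas=1, rank=0, shuffle=True)
    loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=sampler,
        collate_fn=TextAudioSpeakerCollate(), num_workers=0)
    for epoch in range(2):
        sampler.set_epoch(epoch)  # reseeds the per-epoch shuffle
        for batch in loader:
            pass  # each batch holds samples of similar spectrogram length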