# dataset.py
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
import json
import os
import pdb
import sys
import traceback
from typing import Dict
from typing import List
from typing import Optional

# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer

from text import cleaned_text_to_sequence
# from config import exp_dir
  16. def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
  17. seq = sequences[0]
  18. ndim = seq.ndim
  19. if axis < 0:
  20. axis += ndim
  21. dtype = seq.dtype
  22. pad_value = dtype.type(pad_value)
  23. seq_lengths = [seq.shape[axis] for seq in sequences]
  24. max_length = np.max(seq_lengths)
  25. padded_sequences = []
  26. for seq, length in zip(sequences, seq_lengths):
  27. padding = (
  28. [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
  29. )
  30. padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
  31. padded_sequences.append(padded_seq)
  32. batch = np.stack(padded_sequences)
  33. return batch
  34. class Text2SemanticDataset(Dataset):
  35. """dataset class for text tokens to semantic model training."""
  36. def __init__(
  37. self,
  38. phoneme_path: str,
  39. semantic_path: str,
  40. max_sample: int = None,
  41. max_sec: int = 100,
  42. pad_val: int = 1024,
  43. # min value of phoneme/sec
  44. min_ps_ratio: int = 3,
  45. # max value of phoneme/sec
  46. max_ps_ratio: int = 25,
  47. ) -> None:
  48. super().__init__()
  49. self.semantic_data = pd.read_csv(
  50. semantic_path, delimiter="\t", encoding="utf-8"
  51. )
  52. # get dict
  53. self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
  54. self.path3 = "%s/3-bert" % (
  55. os.path.basename(phoneme_path)
  56. ) # "%s/3-bert"%exp_dir#bert_dir
  57. self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
  58. assert os.path.exists(self.path2)
  59. assert os.path.exists(self.path6)
  60. self.phoneme_data = {}
  61. with open(self.path2, "r", encoding="utf8") as f:
  62. lines = f.read().strip("\n").split("\n")
  63. for line in lines:
  64. tmp = line.split("\t")
  65. if len(tmp) != 4:
  66. continue
  67. self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]
  68. # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
  69. # pad for semantic tokens
  70. self.PAD: int = pad_val
  71. # self.hz = 25
  72. # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
  73. # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
  74. # self.hz=int(data[:-2])#
  75. self.hz = int(os.environ.get("hz", "25hz")[:-2])
  76. # max seconds of semantic token
  77. self.max_sec = max_sec
  78. self.min_ps_ratio = min_ps_ratio
  79. self.max_ps_ratio = max_ps_ratio
  80. if max_sample is not None:
  81. self.semantic_data = self.semantic_data[:max_sample]
  82. # {idx: (semantic, phoneme)}
  83. # semantic list, phoneme list
  84. self.semantic_phoneme = []
  85. self.item_names = []
  86. self.inited = False
  87. if not self.inited:
  88. # 调用初始化函数
  89. self.init_batch()
  90. self.inited = True
  91. del self.semantic_data
  92. del self.phoneme_data
  93. # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
  94. # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
  95. def init_batch(self):
  96. semantic_data_len = len(self.semantic_data)
  97. phoneme_data_len = len(self.phoneme_data.keys())
  98. print("semantic_data_len:", semantic_data_len)
  99. print("phoneme_data_len:", phoneme_data_len)
  100. print(self.semantic_data)
  101. idx = 0
  102. num_not_in = 0
  103. num_deleted_bigger = 0
  104. num_deleted_ps = 0
  105. for i in range(semantic_data_len):
  106. # 先依次遍历
  107. # get str
  108. item_name = self.semantic_data.iloc[i,0]
  109. # print(self.phoneme_data)
  110. try:
  111. phoneme, word2ph, text = self.phoneme_data[item_name]
  112. except Exception:
  113. traceback.print_exc()
  114. # print(f"{item_name} not in self.phoneme_data !")
  115. num_not_in += 1
  116. continue
  117. semantic_str = self.semantic_data.iloc[i,1]
  118. # get token list
  119. semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
  120. # (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
  121. # 过滤掉太长的样本
  122. if (
  123. len(semantic_ids) > self.max_sec * self.hz
  124. ): #########1###根据token个数推测总时长过滤时长60s(config里)#40*25=1k
  125. num_deleted_bigger += 1
  126. continue
  127. # (T, ), 这个速度不会很慢,所以可以在一开始就处理,无需在 __getitem__ 里面单个处理####
  128. phoneme = phoneme.split(" ")
  129. try:
  130. phoneme_ids = cleaned_text_to_sequence(phoneme)
  131. except:
  132. traceback.print_exc()
  133. # print(f"{item_name} not in self.phoneme_data !")
  134. num_not_in += 1
  135. continue
  136. # if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
  137. if (
  138. len(phoneme_ids) > self.max_sec * self.hz / 2.5
  139. ): ###########2:改为恒定限制为semantic/2.5就行
  140. num_deleted_ps += 1
  141. continue
  142. # if len(semantic_ids) > 1000:###########3
  143. # num_deleted_bigger += 1
  144. # continue
  145. ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
  146. if (
  147. ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
  148. ): ##########4#3~25#每秒多少个phone
  149. num_deleted_ps += 1
  150. # print(item_name)
  151. continue
  152. self.semantic_phoneme.append((semantic_ids, phoneme_ids))
  153. idx += 1
  154. self.item_names.append(item_name)
  155. min_num = 100 # 20直接不补#30补了也不存ckpt
  156. leng = len(self.semantic_phoneme)
  157. if leng < min_num:
  158. tmp1 = self.semantic_phoneme
  159. tmp2 = self.item_names
  160. self.semantic_phoneme = []
  161. self.item_names = []
  162. for _ in range(max(2, int(min_num / leng))):
  163. self.semantic_phoneme += tmp1
  164. self.item_names += tmp2
  165. if num_not_in > 0:
  166. print(f"there are {num_not_in} semantic datas not in phoneme datas")
  167. if num_deleted_bigger > 0:
  168. print(
  169. f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
  170. )
  171. if num_deleted_ps > 0:
  172. # 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
  173. print(
  174. f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
  175. )
  176. """
  177. there are 31 semantic datas not in phoneme datas
  178. deleted 34 audios who's duration are bigger than 54 seconds
  179. deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
  180. dataset.__len__(): 366463
  181. """
  182. # 345410 for LibriTTS
  183. print("dataset.__len__():", self.__len__())
  184. def __get_item_names__(self) -> List[str]:
  185. return self.item_names
  186. def __len__(self) -> int:
  187. return len(self.semantic_phoneme)
  188. def __getitem__(self, idx: int) -> Dict:
  189. semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
  190. item_name = self.item_names[idx]
  191. phoneme_ids_len = len(phoneme_ids)
  192. # semantic tokens target
  193. semantic_ids_len = len(semantic_ids)
  194. flag = 0
  195. path_bert = "%s/%s.pt" % (self.path3, item_name)
  196. if os.path.exists(path_bert) == True:
  197. bert_feature = torch.load(path_bert, map_location="cpu")
  198. else:
  199. flag = 1
  200. if flag == 1:
  201. # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
  202. bert_feature = None
  203. else:
  204. assert bert_feature.shape[-1] == len(phoneme_ids)
  205. return {
  206. "idx": idx,
  207. "phoneme_ids": phoneme_ids,
  208. "phoneme_ids_len": phoneme_ids_len,
  209. "semantic_ids": semantic_ids,
  210. "semantic_ids_len": semantic_ids_len,
  211. "bert_feature": bert_feature,
  212. }
  213. def get_sample_length(self, idx: int):
  214. semantic_ids = self.semantic_phoneme[idx][0]
  215. sec = 1.0 * len(semantic_ids) / self.hz
  216. return sec
  217. def collate(self, examples: List[Dict]) -> Dict:
  218. sample_index: List[int] = []
  219. phoneme_ids: List[torch.Tensor] = []
  220. phoneme_ids_lens: List[int] = []
  221. semantic_ids: List[torch.Tensor] = []
  222. semantic_ids_lens: List[int] = []
  223. # return
  224. for item in examples:
  225. sample_index.append(item["idx"])
  226. phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
  227. semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
  228. phoneme_ids_lens.append(item["phoneme_ids_len"])
  229. semantic_ids_lens.append(item["semantic_ids_len"])
  230. # pad 0
  231. phoneme_ids = batch_sequences(phoneme_ids)
  232. semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
  233. # # convert each batch to torch.tensor
  234. phoneme_ids = torch.tensor(phoneme_ids)
  235. semantic_ids = torch.tensor(semantic_ids)
  236. phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
  237. semantic_ids_lens = torch.tensor(semantic_ids_lens)
  238. bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
  239. bert_padded.zero_()
  240. for idx, item in enumerate(examples):
  241. bert = item["bert_feature"]
  242. if bert != None:
  243. bert_padded[idx, :, : bert.shape[-1]] = bert
  244. return {
  245. # List[int]
  246. "ids": sample_index,
  247. # torch.Tensor (B, max_phoneme_length)
  248. "phoneme_ids": phoneme_ids,
  249. # torch.Tensor (B)
  250. "phoneme_ids_len": phoneme_ids_lens,
  251. # torch.Tensor (B, max_semantic_ids_length)
  252. "semantic_ids": semantic_ids,
  253. # torch.Tensor (B)
  254. "semantic_ids_len": semantic_ids_lens,
  255. # torch.Tensor (B, 1024, max_phoneme_length)
  256. "bert_feature": bert_padded,
  257. }
  258. if __name__ == "__main__":
  259. root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
  260. dataset = Text2SemanticDataset(
  261. phoneme_path=root_dir + "phoneme_train.npy",
  262. semantic_path=root_dir + "semantic_train.tsv",
  263. )
  264. batch_size = 12
  265. dataloader = DataLoader(
  266. dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
  267. )
  268. for i, batch in enumerate(dataloader):
  269. if i % 1000 == 0:
  270. print(i)
  271. # if i == 0:
  272. # print('batch["ids"]:', batch["ids"])
  273. # print('batch["phoneme_ids"]:', batch["phoneme_ids"],
  274. # batch["phoneme_ids"].shape)
  275. # print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
  276. # batch["phoneme_ids_len"].shape)
  277. # print('batch["semantic_ids"]:', batch["semantic_ids"],
  278. # batch["semantic_ids"].shape)
  279. # print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
  280. # batch["semantic_ids_len"].shape)