# onnx_export.py (13 KB)
  1. from module.models_onnx import SynthesizerTrn, symbols
  2. from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
  3. import torch
  4. import torchaudio
  5. from torch import nn
  6. from feature_extractor import cnhubert
  7. cnhubert_base_path = "pretrained_models/chinese-hubert-base"
  8. cnhubert.cnhubert_base_path=cnhubert_base_path
  9. ssl_model = cnhubert.get_model()
  10. from text import cleaned_text_to_sequence
  11. import soundfile
  12. from my_utils import load_audio
  13. import os
  14. import json
  15. def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
  16. hann_window = torch.hann_window(win_size).to(
  17. dtype=y.dtype, device=y.device
  18. )
  19. y = torch.nn.functional.pad(
  20. y.unsqueeze(1),
  21. (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
  22. mode="reflect",
  23. )
  24. y = y.squeeze(1)
  25. spec = torch.stft(
  26. y,
  27. n_fft,
  28. hop_length=hop_size,
  29. win_length=win_size,
  30. window=hann_window,
  31. center=center,
  32. pad_mode="reflect",
  33. normalized=False,
  34. onesided=True,
  35. return_complex=False,
  36. )
  37. spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
  38. return spec
  39. class DictToAttrRecursive(dict):
  40. def __init__(self, input_dict):
  41. super().__init__(input_dict)
  42. for key, value in input_dict.items():
  43. if isinstance(value, dict):
  44. value = DictToAttrRecursive(value)
  45. self[key] = value
  46. setattr(self, key, value)
  47. def __getattr__(self, item):
  48. try:
  49. return self[item]
  50. except KeyError:
  51. raise AttributeError(f"Attribute {item} not found")
  52. def __setattr__(self, key, value):
  53. if isinstance(value, dict):
  54. value = DictToAttrRecursive(value)
  55. super(DictToAttrRecursive, self).__setitem__(key, value)
  56. super().__setattr__(key, value)
  57. def __delattr__(self, item):
  58. try:
  59. del self[item]
  60. except KeyError:
  61. raise AttributeError(f"Attribute {item} not found")
  62. class T2SEncoder(nn.Module):
  63. def __init__(self, t2s, vits):
  64. super().__init__()
  65. self.encoder = t2s.onnx_encoder
  66. self.vits = vits
  67. def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
  68. codes = self.vits.extract_latent(ssl_content)
  69. prompt_semantic = codes[0, 0]
  70. bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
  71. all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
  72. bert = bert.unsqueeze(0)
  73. prompt = prompt_semantic.unsqueeze(0)
  74. return self.encoder(all_phoneme_ids, bert), prompt
  75. class T2SModel(nn.Module):
  76. def __init__(self, t2s_path, vits_model):
  77. super().__init__()
  78. dict_s1 = torch.load(t2s_path, map_location="cpu")
  79. self.config = dict_s1["config"]
  80. self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
  81. self.t2s_model.load_state_dict(dict_s1["weight"])
  82. self.t2s_model.eval()
  83. self.vits_model = vits_model.vq_model
  84. self.hz = 50
  85. self.max_sec = self.config["data"]["max_sec"]
  86. self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
  87. self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
  88. self.t2s_model = self.t2s_model.model
  89. self.t2s_model.init_onnx()
  90. self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
  91. self.first_stage_decoder = self.t2s_model.first_stage_decoder
  92. self.stage_decoder = self.t2s_model.stage_decoder
  93. #self.t2s_model = torch.jit.script(self.t2s_model)
  94. def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
  95. early_stop_num = self.t2s_model.early_stop_num
  96. #[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
  97. x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
  98. prefix_len = prompts.shape[1]
  99. #[1,N,512] [1,N]
  100. y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
  101. stop = False
  102. for idx in range(1, 1500):
  103. #[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
  104. enco = self.stage_decoder(y, k, v, y_emb, x_example)
  105. y, k, v, y_emb, logits, samples = enco
  106. if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
  107. stop = True
  108. if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
  109. stop = True
  110. if stop:
  111. break
  112. y[0, -1] = 0
  113. return y[:, -idx:].unsqueeze(0)
  114. def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
  115. #self.onnx_encoder = torch.jit.script(self.onnx_encoder)
  116. if dynamo:
  117. export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
  118. onnx_encoder_export_output = torch.onnx.dynamo_export(
  119. self.onnx_encoder,
  120. (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
  121. export_options=export_options
  122. )
  123. onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
  124. return
  125. torch.onnx.export(
  126. self.onnx_encoder,
  127. (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
  128. f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
  129. input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
  130. output_names=["x", "prompts"],
  131. dynamic_axes={
  132. "ref_seq": {1 : "ref_length"},
  133. "text_seq": {1 : "text_length"},
  134. "ref_bert": {0 : "ref_length"},
  135. "text_bert": {0 : "text_length"},
  136. "ssl_content": {2 : "ssl_length"},
  137. },
  138. opset_version=16
  139. )
  140. x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
  141. torch.onnx.export(
  142. self.first_stage_decoder,
  143. (x, prompts),
  144. f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
  145. input_names=["x", "prompts"],
  146. output_names=["y", "k", "v", "y_emb", "x_example"],
  147. dynamic_axes={
  148. "x": {1 : "x_length"},
  149. "prompts": {1 : "prompts_length"},
  150. },
  151. verbose=False,
  152. opset_version=16
  153. )
  154. y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
  155. torch.onnx.export(
  156. self.stage_decoder,
  157. (y, k, v, y_emb, x_example),
  158. f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
  159. input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
  160. output_names=["y", "k", "v", "y_emb", "logits", "samples"],
  161. dynamic_axes={
  162. "iy": {1 : "iy_length"},
  163. "ik": {1 : "ik_length"},
  164. "iv": {1 : "iv_length"},
  165. "iy_emb": {1 : "iy_emb_length"},
  166. "ix_example": {1 : "ix_example_length"},
  167. },
  168. verbose=False,
  169. opset_version=16
  170. )
  171. class VitsModel(nn.Module):
  172. def __init__(self, vits_path):
  173. super().__init__()
  174. dict_s2 = torch.load(vits_path,map_location="cpu")
  175. self.hps = dict_s2["config"]
  176. self.hps = DictToAttrRecursive(self.hps)
  177. self.hps.model.semantic_frame_rate = "25hz"
  178. self.vq_model = SynthesizerTrn(
  179. self.hps.data.filter_length // 2 + 1,
  180. self.hps.train.segment_size // self.hps.data.hop_length,
  181. n_speakers=self.hps.data.n_speakers,
  182. **self.hps.model
  183. )
  184. self.vq_model.eval()
  185. self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
  186. def forward(self, text_seq, pred_semantic, ref_audio):
  187. refer = spectrogram_torch(
  188. ref_audio,
  189. self.hps.data.filter_length,
  190. self.hps.data.sampling_rate,
  191. self.hps.data.hop_length,
  192. self.hps.data.win_length,
  193. center=False
  194. )
  195. return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
  196. class GptSoVits(nn.Module):
  197. def __init__(self, vits, t2s):
  198. super().__init__()
  199. self.vits = vits
  200. self.t2s = t2s
  201. def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
  202. pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
  203. audio = self.vits(text_seq, pred_semantic, ref_audio)
  204. if debug:
  205. import onnxruntime
  206. sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
  207. audio1 = sess.run(None, {
  208. "text_seq" : text_seq.detach().cpu().numpy(),
  209. "pred_semantic" : pred_semantic.detach().cpu().numpy(),
  210. "ref_audio" : ref_audio.detach().cpu().numpy()
  211. })
  212. return audio, audio1
  213. return audio
  214. def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
  215. self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
  216. pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
  217. torch.onnx.export(
  218. self.vits,
  219. (text_seq, pred_semantic, ref_audio),
  220. f"onnx/{project_name}/{project_name}_vits.onnx",
  221. input_names=["text_seq", "pred_semantic", "ref_audio"],
  222. output_names=["audio"],
  223. dynamic_axes={
  224. "text_seq": {1 : "text_length"},
  225. "pred_semantic": {2 : "pred_length"},
  226. "ref_audio": {1 : "audio_length"},
  227. },
  228. opset_version=17,
  229. verbose=False
  230. )
  231. class SSLModel(nn.Module):
  232. def __init__(self):
  233. super().__init__()
  234. self.ssl = ssl_model
  235. def forward(self, ref_audio_16k):
  236. return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
  237. def export(vits_path, gpt_path, project_name):
  238. vits = VitsModel(vits_path)
  239. gpt = T2SModel(gpt_path, vits)
  240. gpt_sovits = GptSoVits(vits, gpt)
  241. ssl = SSLModel()
  242. ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
  243. text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
  244. ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
  245. text_bert = torch.randn((text_seq.shape[1], 1024)).float()
  246. ref_audio = torch.randn((1, 48000 * 5)).float()
  247. # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
  248. ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
  249. ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
  250. try:
  251. os.mkdir(f"onnx/{project_name}")
  252. except:
  253. pass
  254. ssl_content = ssl(ref_audio_16k).float()
  255. debug = False
  256. if debug:
  257. a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
  258. soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
  259. soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
  260. return
  261. a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
  262. soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
  263. gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
  264. MoeVSConf = {
  265. "Folder" : f"{project_name}",
  266. "Name" : f"{project_name}",
  267. "Type" : "GPT-SoVits",
  268. "Rate" : vits.hps.data.sampling_rate,
  269. "NumLayers": gpt.t2s_model.num_layers,
  270. "EmbeddingDim": gpt.t2s_model.embedding_dim,
  271. "Dict": "BasicDict",
  272. "BertPath": "chinese-roberta-wwm-ext-large",
  273. "Symbol": symbols,
  274. "AddBlank": False
  275. }
  276. MoeVSConfJson = json.dumps(MoeVSConf)
  277. with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
  278. json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
  279. if __name__ == "__main__":
  280. try:
  281. os.mkdir("onnx")
  282. except:
  283. pass
  284. gpt_path = "GPT_weights/nahida-e25.ckpt"
  285. vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
  286. exp_path = "nahida"
  287. export(vits_path, gpt_path, exp_path)
  288. # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)