inference_webui.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676
'''
Text-language handling modes supported by this inference WebUI:
1. Mixed Chinese/English recognition
2. Mixed Japanese/English recognition
3. Multilingual: split the text and detect each segment's language
4. Treat all text as Chinese
5. Treat all text as English
6. Treat all text as Japanese
'''
import os, re, logging
import LangSegment
# Silence chatty third-party loggers; only errors are shown.
logging.getLogger("markdown_it").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("asyncio").setLevel(logging.ERROR)
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
# NOTE(review): pdb is not referenced anywhere else in this file.
import pdb
  19. if os.path.exists("./gweight.txt"):
  20. with open("./gweight.txt", 'r', encoding="utf-8") as file:
  21. gweight_data = file.read()
  22. gpt_path = os.environ.get(
  23. "gpt_path", gweight_data)
  24. else:
  25. gpt_path = os.environ.get(
  26. "gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
  27. if os.path.exists("./sweight.txt"):
  28. with open("./sweight.txt", 'r', encoding="utf-8") as file:
  29. sweight_data = file.read()
  30. sovits_path = os.environ.get("sovits_path", sweight_data)
  31. else:
  32. sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
  33. # gpt_path = os.environ.get(
  34. # "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
  35. # )
  36. # sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
  37. cnhubert_base_path = os.environ.get(
  38. "cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
  39. )
  40. bert_path = os.environ.get(
  41. "bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
  42. )
  43. infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
  44. infer_ttswebui = int(infer_ttswebui)
  45. is_share = os.environ.get("is_share", "False")
  46. is_share = eval(is_share)
  47. if "_CUDA_VISIBLE_DEVICES" in os.environ:
  48. os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
  49. is_half = eval(os.environ.get("is_half", "True"))
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
import librosa, torch
from feature_extractor import cnhubert
# Point the HuBERT feature extractor at the configured model directory
# before cnhubert.get_model() is called further down.
cnhubert.cnhubert_base_path = cnhubert_base_path
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # Make sure this is also set when the inference UI is launched directly.
# NOTE(review): this is set after `import torch`; confirm PyTorch still honours it when set this late.
# Choose the compute device: CUDA if available, else Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# Tokenizer and masked-LM used by get_bert_feature (default: chinese-roberta-wwm-ext-large).
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half == True:
    bert_model = bert_model.half().to(device)
else:
    bert_model = bert_model.to(device)
  78. def get_bert_feature(text, word2ph):
  79. with torch.no_grad():
  80. inputs = tokenizer(text, return_tensors="pt")
  81. for i in inputs:
  82. inputs[i] = inputs[i].to(device)
  83. res = bert_model(**inputs, output_hidden_states=True)
  84. res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
  85. assert len(word2ph) == len(text)
  86. phone_level_feature = []
  87. for i in range(len(word2ph)):
  88. repeat_feature = res[i].repeat(word2ph[i], 1)
  89. phone_level_feature.append(repeat_feature)
  90. phone_level_feature = torch.cat(phone_level_feature, dim=0)
  91. return phone_level_feature.T
  92. class DictToAttrRecursive(dict):
  93. def __init__(self, input_dict):
  94. super().__init__(input_dict)
  95. for key, value in input_dict.items():
  96. if isinstance(value, dict):
  97. value = DictToAttrRecursive(value)
  98. self[key] = value
  99. setattr(self, key, value)
  100. def __getattr__(self, item):
  101. try:
  102. return self[item]
  103. except KeyError:
  104. raise AttributeError(f"Attribute {item} not found")
  105. def __setattr__(self, key, value):
  106. if isinstance(value, dict):
  107. value = DictToAttrRecursive(value)
  108. super(DictToAttrRecursive, self).__setitem__(key, value)
  109. super().__setattr__(key, value)
  110. def __delattr__(self, item):
  111. try:
  112. del self[item]
  113. except KeyError:
  114. raise AttributeError(f"Attribute {item} not found")
  115. ssl_model = cnhubert.get_model()
  116. if is_half == True:
  117. ssl_model = ssl_model.half().to(device)
  118. else:
  119. ssl_model = ssl_model.to(device)
def change_sovits_weights(sovits_path):
    """Load a SoVITS checkpoint into the globals ``vq_model`` and ``hps``.

    Reads the checkpoint's "config" entry (wrapped for attribute access) and its
    "weight" entry.  Builds ``SynthesizerTrn`` from the config and moves it to
    ``device`` (half precision when ``is_half``).  Loads the weights non-strictly
    and records the path in ./sweight.txt so the next launch reuses it.  Also
    used as the change callback of the SoVITS model dropdown.
    """
    global vq_model, hps
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    dict_s2 = torch.load(sovits_path, map_location="cpu")
    hps = dict_s2["config"]
    hps = DictToAttrRecursive(hps)
    hps.model.semantic_frame_rate = "25hz"
    vq_model = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model
    )
    # Checkpoints whose path lacks "pretrained" drop the enc_q submodule
    # (presumably only needed for training — TODO confirm).
    if ("pretrained" not in sovits_path):
        del vq_model.enc_q
    if is_half == True:
        vq_model = vq_model.half().to(device)
    else:
        vq_model = vq_model.to(device)
    vq_model.eval()
    # Prints whatever load_state_dict returns (with strict=False, presumably the
    # missing/unexpected key report).
    print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
    with open("./sweight.txt", "w", encoding="utf-8") as f:
        f.write(sovits_path)
# Load the initial SoVITS model at import time.
change_sovits_weights(sovits_path)
def change_gpt_weights(gpt_path):
    """Load a GPT (Text2SemanticLightningModule) checkpoint into the global ``t2s_model``.

    Also sets the globals ``config``, ``hz`` and ``max_sec``.  Their product
    ``hz * max_sec`` is passed as ``early_stop_num`` to ``infer_panel`` in
    get_tts_wav.  The path is recorded in ./gweight.txt so the next launch
    reuses it.  Also used as the change callback of the GPT model dropdown.
    """
    global hz, max_sec, t2s_model, config
    hz = 50
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    dict_s1 = torch.load(gpt_path, map_location="cpu")
    config = dict_s1["config"]
    max_sec = config["data"]["max_sec"]
    t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
    t2s_model.load_state_dict(dict_s1["weight"])
    if is_half == True:
        t2s_model = t2s_model.half()
    t2s_model = t2s_model.to(device)
    t2s_model.eval()
    # Report the model size in millions of parameters.
    total = sum([param.nelement() for param in t2s_model.parameters()])
    print("Number of parameter: %.2fM" % (total / 1e6))
    with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
# Load the initial GPT model at import time.
change_gpt_weights(gpt_path)
  159. def get_spepc(hps, filename):
  160. audio = load_audio(filename, int(hps.data.sampling_rate))
  161. audio = torch.FloatTensor(audio)
  162. audio_norm = audio
  163. audio_norm = audio_norm.unsqueeze(0)
  164. spec = spectrogram_torch(
  165. audio_norm,
  166. hps.data.filter_length,
  167. hps.data.sampling_rate,
  168. hps.data.hop_length,
  169. hps.data.win_length,
  170. center=False,
  171. )
  172. return spec
# Maps the (i18n-localized) UI language label to the internal language code
# consumed by get_tts_wav, get_cleaned_text_final and get_bert_final.
dict_language = {
    i18n("中文"): "all_zh",  # treat all text as Chinese
    i18n("英文"): "en",  # treat all text as English (code unchanged)
    i18n("日文"): "all_ja",  # treat all text as Japanese
    i18n("中英混合"): "zh",  # mixed Chinese/English (code unchanged)
    i18n("日英混合"): "ja",  # mixed Japanese/English (code unchanged)
    i18n("多语种混合"): "auto",  # multilingual: split and detect each segment's language
}
  181. def splite_en_inf(sentence, language):
  182. pattern = re.compile(r'[a-zA-Z ]+')
  183. textlist = []
  184. langlist = []
  185. pos = 0
  186. for match in pattern.finditer(sentence):
  187. start, end = match.span()
  188. if start > pos:
  189. textlist.append(sentence[pos:start])
  190. langlist.append(language)
  191. textlist.append(sentence[start:end])
  192. langlist.append("en")
  193. pos = end
  194. if pos < len(sentence):
  195. textlist.append(sentence[pos:])
  196. langlist.append(language)
  197. # Merge punctuation into previous word
  198. for i in range(len(textlist)-1, 0, -1):
  199. if re.match(r'^[\W_]+$', textlist[i]):
  200. textlist[i-1] += textlist[i]
  201. del textlist[i]
  202. del langlist[i]
  203. # Merge consecutive words with the same language tag
  204. i = 0
  205. while i < len(langlist) - 1:
  206. if langlist[i] == langlist[i+1]:
  207. textlist[i] += textlist[i+1]
  208. del textlist[i+1]
  209. del langlist[i+1]
  210. else:
  211. i += 1
  212. return textlist, langlist
  213. def clean_text_inf(text, language):
  214. formattext = ""
  215. language = language.replace("all_","")
  216. for tmp in LangSegment.getTexts(text):
  217. if language == "ja":
  218. if tmp["lang"] == language or tmp["lang"] == "zh":
  219. formattext += tmp["text"] + " "
  220. continue
  221. if tmp["lang"] == language:
  222. formattext += tmp["text"] + " "
  223. while " " in formattext:
  224. formattext = formattext.replace(" ", " ")
  225. phones, word2ph, norm_text = clean_text(formattext, language)
  226. phones = cleaned_text_to_sequence(phones)
  227. return phones, word2ph, norm_text
  228. dtype=torch.float16 if is_half == True else torch.float32
  229. def get_bert_inf(phones, word2ph, norm_text, language):
  230. language=language.replace("all_","")
  231. if language == "zh":
  232. bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
  233. else:
  234. bert = torch.zeros(
  235. (1024, len(phones)),
  236. dtype=torch.float16 if is_half == True else torch.float32,
  237. ).to(device)
  238. return bert
  239. def nonen_clean_text_inf(text, language):
  240. if(language!="auto"):
  241. textlist, langlist = splite_en_inf(text, language)
  242. else:
  243. textlist=[]
  244. langlist=[]
  245. for tmp in LangSegment.getTexts(text):
  246. langlist.append(tmp["lang"])
  247. textlist.append(tmp["text"])
  248. phones_list = []
  249. word2ph_list = []
  250. norm_text_list = []
  251. for i in range(len(textlist)):
  252. lang = langlist[i]
  253. phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
  254. phones_list.append(phones)
  255. if lang == "zh":
  256. word2ph_list.append(word2ph)
  257. norm_text_list.append(norm_text)
  258. print(word2ph_list)
  259. phones = sum(phones_list, [])
  260. word2ph = sum(word2ph_list, [])
  261. norm_text = ' '.join(norm_text_list)
  262. return phones, word2ph, norm_text
  263. def nonen_get_bert_inf(text, language):
  264. if(language!="auto"):
  265. textlist, langlist = splite_en_inf(text, language)
  266. else:
  267. textlist=[]
  268. langlist=[]
  269. for tmp in LangSegment.getTexts(text):
  270. langlist.append(tmp["lang"])
  271. textlist.append(tmp["text"])
  272. print(textlist)
  273. print(langlist)
  274. bert_list = []
  275. for i in range(len(textlist)):
  276. lang = langlist[i]
  277. phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
  278. bert = get_bert_inf(phones, word2ph, norm_text, lang)
  279. bert_list.append(bert)
  280. bert = torch.cat(bert_list, dim=1)
  281. return bert
  282. splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
  283. def get_first(text):
  284. pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
  285. text = re.split(pattern, text)[0].strip()
  286. return text
  287. def get_cleaned_text_final(text,language):
  288. if language in {"en","all_zh","all_ja"}:
  289. phones, word2ph, norm_text = clean_text_inf(text, language)
  290. elif language in {"zh", "ja","auto"}:
  291. phones, word2ph, norm_text = nonen_clean_text_inf(text, language)
  292. return phones, word2ph, norm_text
  293. def get_bert_final(phones, word2ph, text,language,device):
  294. if language == "en":
  295. bert = get_bert_inf(phones, word2ph, text, language)
  296. elif language in {"zh", "ja","auto"}:
  297. bert = nonen_get_bert_inf(text, language)
  298. elif language == "all_zh":
  299. bert = get_bert_feature(text, word2ph).to(device)
  300. else:
  301. bert = torch.zeros((1024, len(phones))).to(device)
  302. return bert
  303. def merge_short_text_in_array(texts, threshold):
  304. if (len(texts)) < 2:
  305. return texts
  306. result = []
  307. text = ""
  308. for ele in texts:
  309. text += ele
  310. if len(text) >= threshold:
  311. result.append(text)
  312. text = ""
  313. if (len(text) > 0):
  314. if len(result) == 0:
  315. result.append(text)
  316. else:
  317. result[len(result) - 1] += text
  318. return result
  319. def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
  320. if prompt_text is None or len(prompt_text) == 0:
  321. ref_free = True
  322. t0 = ttime()
  323. prompt_language = dict_language[prompt_language]
  324. text_language = dict_language[text_language]
  325. if not ref_free:
  326. prompt_text = prompt_text.strip("\n")
  327. if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
  328. print(i18n("实际输入的参考文本:"), prompt_text)
  329. text = text.strip("\n")
  330. if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
  331. print(i18n("实际输入的目标文本:"), text)
  332. zero_wav = np.zeros(
  333. int(hps.data.sampling_rate * 0.3),
  334. dtype=np.float16 if is_half == True else np.float32,
  335. )
  336. with torch.no_grad():
  337. wav16k, sr = librosa.load(ref_wav_path, sr=16000)
  338. if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
  339. raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
  340. wav16k = torch.from_numpy(wav16k)
  341. zero_wav_torch = torch.from_numpy(zero_wav)
  342. if is_half == True:
  343. wav16k = wav16k.half().to(device)
  344. zero_wav_torch = zero_wav_torch.half().to(device)
  345. else:
  346. wav16k = wav16k.to(device)
  347. zero_wav_torch = zero_wav_torch.to(device)
  348. wav16k = torch.cat([wav16k, zero_wav_torch])
  349. ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
  350. "last_hidden_state"
  351. ].transpose(
  352. 1, 2
  353. ) # .float()
  354. codes = vq_model.extract_latent(ssl_content)
  355. prompt_semantic = codes[0, 0]
  356. t1 = ttime()
  357. if (how_to_cut == i18n("凑四句一切")):
  358. text = cut1(text)
  359. elif (how_to_cut == i18n("凑50字一切")):
  360. text = cut2(text)
  361. elif (how_to_cut == i18n("按中文句号。切")):
  362. text = cut3(text)
  363. elif (how_to_cut == i18n("按英文句号.切")):
  364. text = cut4(text)
  365. elif (how_to_cut == i18n("按标点符号切")):
  366. text = cut5(text)
  367. while "\n\n" in text:
  368. text = text.replace("\n\n", "\n")
  369. print(i18n("实际输入的目标文本(切句后):"), text)
  370. texts = text.split("\n")
  371. texts = merge_short_text_in_array(texts, 5)
  372. audio_opt = []
  373. if not ref_free:
  374. phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language)
  375. bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype)
  376. for text in texts:
  377. # 解决输入目标文本的空行导致报错的问题
  378. if (len(text.strip()) == 0):
  379. continue
  380. if (text[-1] not in splits): text += "。" if text_language != "en" else "."
  381. print(i18n("实际输入的目标文本(每句):"), text)
  382. phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language)
  383. bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype)
  384. if not ref_free:
  385. bert = torch.cat([bert1, bert2], 1)
  386. all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
  387. else:
  388. bert = bert2
  389. all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
  390. bert = bert.to(device).unsqueeze(0)
  391. all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
  392. prompt = prompt_semantic.unsqueeze(0).to(device)
  393. t2 = ttime()
  394. with torch.no_grad():
  395. # pred_semantic = t2s_model.model.infer(
  396. pred_semantic, idx = t2s_model.model.infer_panel(
  397. all_phoneme_ids,
  398. all_phoneme_len,
  399. None if ref_free else prompt,
  400. bert,
  401. # prompt_phone_len=ph_offset,
  402. top_k=top_k,
  403. top_p=top_p,
  404. temperature=temperature,
  405. early_stop_num=hz * max_sec,
  406. )
  407. t3 = ttime()
  408. # print(pred_semantic.shape,idx)
  409. pred_semantic = pred_semantic[:, -idx:].unsqueeze(
  410. 0
  411. ) # .unsqueeze(0)#mq要多unsqueeze一次
  412. refer = get_spepc(hps, ref_wav_path) # .to(device)
  413. if is_half == True:
  414. refer = refer.half().to(device)
  415. else:
  416. refer = refer.to(device)
  417. # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
  418. audio = (
  419. vq_model.decode(
  420. pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
  421. )
  422. .detach()
  423. .cpu()
  424. .numpy()[0, 0]
  425. ) ###试试重建不带上prompt部分
  426. max_audio=np.abs(audio).max()#简单防止16bit爆音
  427. if max_audio>1:audio/=max_audio
  428. audio_opt.append(audio)
  429. audio_opt.append(zero_wav)
  430. t4 = ttime()
  431. print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
  432. yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
  433. np.int16
  434. )
  435. def split(todo_text):
  436. todo_text = todo_text.replace("……", "。").replace("——", ",")
  437. if todo_text[-1] not in splits:
  438. todo_text += "。"
  439. i_split_head = i_split_tail = 0
  440. len_text = len(todo_text)
  441. todo_texts = []
  442. while 1:
  443. if i_split_head >= len_text:
  444. break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
  445. if todo_text[i_split_head] in splits:
  446. i_split_head += 1
  447. todo_texts.append(todo_text[i_split_tail:i_split_head])
  448. i_split_tail = i_split_head
  449. else:
  450. i_split_head += 1
  451. return todo_texts
  452. def cut1(inp):
  453. inp = inp.strip("\n")
  454. inps = split(inp)
  455. split_idx = list(range(0, len(inps), 4))
  456. split_idx[-1] = None
  457. if len(split_idx) > 1:
  458. opts = []
  459. for idx in range(len(split_idx) - 1):
  460. opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
  461. else:
  462. opts = [inp]
  463. return "\n".join(opts)
  464. def cut2(inp):
  465. inp = inp.strip("\n")
  466. inps = split(inp)
  467. if len(inps) < 2:
  468. return inp
  469. opts = []
  470. summ = 0
  471. tmp_str = ""
  472. for i in range(len(inps)):
  473. summ += len(inps[i])
  474. tmp_str += inps[i]
  475. if summ > 50:
  476. summ = 0
  477. opts.append(tmp_str)
  478. tmp_str = ""
  479. if tmp_str != "":
  480. opts.append(tmp_str)
  481. # print(opts)
  482. if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
  483. opts[-2] = opts[-2] + opts[-1]
  484. opts = opts[:-1]
  485. return "\n".join(opts)
  486. def cut3(inp):
  487. inp = inp.strip("\n")
  488. return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
  489. def cut4(inp):
  490. inp = inp.strip("\n")
  491. return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
  492. # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
  493. def cut5(inp):
  494. # if not re.search(r'[^\w\s]', inp[-1]):
  495. # inp += '。'
  496. inp = inp.strip("\n")
  497. punds = r'[,.;?!、,。?!;:]'
  498. items = re.split(f'({punds})', inp)
  499. items = ["".join(group) for group in zip(items[::2], items[1::2])]
  500. opt = "\n".join(items)
  501. return opt
  502. def custom_sort_key(s):
  503. # 使用正则表达式提取字符串中的数字部分和非数字部分
  504. parts = re.split('(\d+)', s)
  505. # 将数字部分转换为整数,非数字部分保持不变
  506. parts = [int(part) if part.isdigit() else part for part in parts]
  507. return parts
  508. def change_choices():
  509. SoVITS_names, GPT_names = get_weights_names()
  510. return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
# Bundled pretrained checkpoints; always included in the model dropdown choices.
pretrained_sovits_name = "GPT_SoVITS/pretrained_models/s2G488k.pth"
pretrained_gpt_name = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
# Folders scanned by get_weights_names for user-trained weights; created if missing.
SoVITS_weight_root = "SoVITS_weights"
GPT_weight_root = "GPT_weights"
os.makedirs(SoVITS_weight_root, exist_ok=True)
os.makedirs(GPT_weight_root, exist_ok=True)
  517. def get_weights_names():
  518. SoVITS_names = [pretrained_sovits_name]
  519. for name in os.listdir(SoVITS_weight_root):
  520. if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (SoVITS_weight_root, name))
  521. GPT_names = [pretrained_gpt_name]
  522. for name in os.listdir(GPT_weight_root):
  523. if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (GPT_weight_root, name))
  524. return SoVITS_names, GPT_names
# Checkpoints offered in the model dropdowns.
SoVITS_names, GPT_names = get_weights_names()
# Build the Gradio interface.
# NOTE(review): the original indentation of this layout was lost; the container
# nesting below is a reconstruction — confirm against the intended layout.
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
    # License / disclaimer banner.
    gr.Markdown(
        value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
    )
    with gr.Group():
        # Model selection: changing a dropdown reloads that model immediately.
        gr.Markdown(value=i18n("模型切换"))
        with gr.Row():
            GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
            SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
            refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
            refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
            SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown], [])
            GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
        # Reference audio, its transcript and language.
        gr.Markdown(value=i18n("*请上传并填写参考信息"))
        with gr.Row():
            inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath")
            with gr.Column():
                ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
                gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT"))
            prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="")
            prompt_language = gr.Dropdown(
                label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
            )
        # Target text, its language and the splitting strategy.
        gr.Markdown(value=i18n("*请填写需要合成的目标文本。中英混合选中文,日英混合选日文,中日混合暂不支持,非目标语言文本自动遗弃。"))
        with gr.Row():
            text = gr.Textbox(label=i18n("需要合成的文本"), value="")
            text_language = gr.Dropdown(
                label=i18n("需要合成的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
            )
            how_to_cut = gr.Radio(
                label=i18n("怎么切"),
                choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
                value=i18n("凑四句一切"),
                interactive=True,
            )
        # GPT sampling parameters (the label warns not to set them too low without reference text).
        with gr.Row():
            gr.Markdown("gpt采样参数(无参考文本时不要太低):")
            top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
            top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
            temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
        inference_button = gr.Button(i18n("合成语音"), variant="primary")
        output = gr.Audio(label=i18n("输出的语音"))
        # Input order must match get_tts_wav's positional parameters.
        inference_button.click(
            get_tts_wav,
            [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
            [output],
        )
        # Standalone text-splitting tool using the same cut1..cut5 functions.
        gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
        with gr.Row():
            text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
            button1 = gr.Button(i18n("凑四句一切"), variant="primary")
            button2 = gr.Button(i18n("凑50字一切"), variant="primary")
            button3 = gr.Button(i18n("按中文句号。切"), variant="primary")
            button4 = gr.Button(i18n("按英文句号.切"), variant="primary")
            button5 = gr.Button(i18n("按标点符号切"), variant="primary")
            text_opt = gr.Textbox(label=i18n("切分后文本"), value="")
            button1.click(cut1, [text_inp], [text_opt])
            button2.click(cut2, [text_inp], [text_opt])
            button3.click(cut3, [text_inp], [text_opt])
            button4.click(cut4, [text_inp], [text_opt])
            button5.click(cut5, [text_inp], [text_opt])
        gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
# Listens on all network interfaces (0.0.0.0) on port infer_ttswebui.
# NOTE(review): `concurrency_count` is a Gradio 3.x queue() argument; confirm the pinned Gradio version.
app.queue(concurrency_count=511, max_size=1022).launch(
    server_name="0.0.0.0",
    inbrowser=True,
    share=is_share,
    server_port=infer_ttswebui,
    quiet=True,
)