api.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. """
  2. # api.py usage
  3. ` python api.py -dr "123.wav" -dt "一二三。" -dl "zh" `
  4. ## 执行参数:
  5. `-s` - `SoVITS模型路径, 可在 config.py 中指定`
  6. `-g` - `GPT模型路径, 可在 config.py 中指定`
  7. 调用请求缺少参考音频时使用
  8. `-dr` - `默认参考音频路径`
  9. `-dt` - `默认参考音频文本`
  10. `-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"`
  11. `-d` - `推理设备, "cuda","cpu","mps"`
  12. `-a` - `绑定地址, 默认"127.0.0.1"`
  13. `-p` - `绑定端口, 默认9880, 可在 config.py 中指定`
  14. `-fp` - `覆盖 config.py 使用全精度`
  15. `-hp` - `覆盖 config.py 使用半精度`
  16. `-hb` - `cnhubert路径`
  17. `-b` - `bert路径`
  18. ## 调用:
  19. ### 推理
  20. endpoint: `/`
  21. 使用执行参数指定的参考音频:
  22. GET:
  23. `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
  24. POST:
  25. ```json
  26. {
  27. "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
  28. "text_language": "zh"
  29. }
  30. ```
  31. 手动指定当次推理所使用的参考音频:
  32. GET:
  33. `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
  34. POST:
  35. ```json
  36. {
  37. "refer_wav_path": "123.wav",
  38. "prompt_text": "一二三。",
  39. "prompt_language": "zh",
  40. "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
  41. "text_language": "zh"
  42. }
  43. ```
  44. RESP:
  45. 成功: 直接返回 wav 音频流, http code 200
  46. 失败: 返回包含错误信息的 json, http code 400
  47. ### 更换默认参考音频
  48. endpoint: `/change_refer`
  49. key与推理端一样
  50. GET:
  51. `http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh`
  52. POST:
  53. ```json
  54. {
  55. "refer_wav_path": "123.wav",
  56. "prompt_text": "一二三。",
  57. "prompt_language": "zh"
  58. }
  59. ```
  60. RESP:
  61. 成功: json, http code 200
  62. 失败: json, 400
  63. ### 命令控制
  64. endpoint: `/control`
  65. command:
  66. "restart": 重新运行
  67. "exit": 结束运行
  68. GET:
  69. `http://127.0.0.1:9880/control?command=restart`
  70. POST:
  71. ```json
  72. {
  73. "command": "restart"
  74. }
  75. ```
  76. RESP: 无
  77. """
  78. import argparse
  79. import os
  80. import sys
  81. now_dir = os.getcwd()
  82. sys.path.append(now_dir)
  83. sys.path.append("%s/GPT_SoVITS" % (now_dir))
  84. import signal
  85. from time import time as ttime
  86. import torch
  87. import librosa
  88. import soundfile as sf
  89. from fastapi import FastAPI, Request, HTTPException
  90. from fastapi.responses import StreamingResponse, JSONResponse
  91. import uvicorn
  92. from transformers import AutoModelForMaskedLM, AutoTokenizer
  93. import numpy as np
  94. from feature_extractor import cnhubert
  95. from io import BytesIO
  96. from module.models import SynthesizerTrn
  97. from AR.models.t2s_lightning_module import Text2SemanticLightningModule
  98. from text import cleaned_text_to_sequence
  99. from text.cleaner import clean_text
  100. from module.mel_processing import spectrogram_torch
  101. from my_utils import load_audio
  102. import config as global_config
  103. g_config = global_config.Config()
  104. # AVAILABLE_COMPUTE = "cuda" if torch.cuda.is_available() else "cpu"
  105. parser = argparse.ArgumentParser(description="GPT-SoVITS api")
  106. parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
  107. parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")
  108. parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
  109. parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
  110. parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
  111. parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu / mps")
  112. parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
  113. parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
  114. parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
  115. parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
  116. # bool值的用法为 `python ./api.py -fp ...`
  117. # 此时 full_precision==True, half_precision==False
  118. parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
  119. parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
  120. args = parser.parse_args()
  121. sovits_path = args.sovits_path
  122. gpt_path = args.gpt_path
  123. class DefaultRefer:
  124. def __init__(self, path, text, language):
  125. self.path = args.default_refer_path
  126. self.text = args.default_refer_text
  127. self.language = args.default_refer_language
  128. def is_ready(self) -> bool:
  129. return is_full(self.path, self.text, self.language)
  130. default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
  131. device = args.device
  132. port = args.port
  133. host = args.bind_addr
  134. if sovits_path == "":
  135. sovits_path = g_config.pretrained_sovits_path
  136. print(f"[WARN] 未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
  137. if gpt_path == "":
  138. gpt_path = g_config.pretrained_gpt_path
  139. print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}")
  140. # 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
  141. if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
  142. default_refer.path, default_refer.text, default_refer.language = "", "", ""
  143. print("[INFO] 未指定默认参考音频")
  144. else:
  145. print(f"[INFO] 默认参考音频路径: {default_refer.path}")
  146. print(f"[INFO] 默认参考音频文本: {default_refer.text}")
  147. print(f"[INFO] 默认参考音频语种: {default_refer.language}")
  148. is_half = g_config.is_half
  149. if args.full_precision:
  150. is_half = False
  151. if args.half_precision:
  152. is_half = True
  153. if args.full_precision and args.half_precision:
  154. is_half = g_config.is_half # 炒饭fallback
  155. print(f"[INFO] 半精: {is_half}")
  156. cnhubert_base_path = args.hubert_path
  157. bert_path = args.bert_path
  158. cnhubert.cnhubert_base_path = cnhubert_base_path
  159. tokenizer = AutoTokenizer.from_pretrained(bert_path)
  160. bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
  161. if is_half:
  162. bert_model = bert_model.half().to(device)
  163. else:
  164. bert_model = bert_model.to(device)
  165. def is_empty(*items): # 任意一项不为空返回False
  166. for item in items:
  167. if item is not None and item != "":
  168. return False
  169. return True
  170. def is_full(*items): # 任意一项为空返回False
  171. for item in items:
  172. if item is None or item == "":
  173. return False
  174. return True
  175. def get_bert_feature(text, word2ph):
  176. with torch.no_grad():
  177. inputs = tokenizer(text, return_tensors="pt")
  178. for i in inputs:
  179. inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
  180. res = bert_model(**inputs, output_hidden_states=True)
  181. res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
  182. assert len(word2ph) == len(text)
  183. phone_level_feature = []
  184. for i in range(len(word2ph)):
  185. repeat_feature = res[i].repeat(word2ph[i], 1)
  186. phone_level_feature.append(repeat_feature)
  187. phone_level_feature = torch.cat(phone_level_feature, dim=0)
  188. # if(is_half==True):phone_level_feature=phone_level_feature.half()
  189. return phone_level_feature.T
  190. n_semantic = 1024
  191. dict_s2 = torch.load(sovits_path, map_location="cpu")
  192. hps = dict_s2["config"]
  193. class DictToAttrRecursive:
  194. def __init__(self, input_dict):
  195. for key, value in input_dict.items():
  196. if isinstance(value, dict):
  197. # 如果值是字典,递归调用构造函数
  198. setattr(self, key, DictToAttrRecursive(value))
  199. else:
  200. setattr(self, key, value)
  201. hps = DictToAttrRecursive(hps)
  202. hps.model.semantic_frame_rate = "25hz"
  203. dict_s1 = torch.load(gpt_path, map_location="cpu")
  204. config = dict_s1["config"]
  205. ssl_model = cnhubert.get_model()
  206. if is_half:
  207. ssl_model = ssl_model.half().to(device)
  208. else:
  209. ssl_model = ssl_model.to(device)
  210. vq_model = SynthesizerTrn(
  211. hps.data.filter_length // 2 + 1,
  212. hps.train.segment_size // hps.data.hop_length,
  213. n_speakers=hps.data.n_speakers,
  214. **hps.model)
  215. if is_half:
  216. vq_model = vq_model.half().to(device)
  217. else:
  218. vq_model = vq_model.to(device)
  219. vq_model.eval()
  220. print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
  221. hz = 50
  222. max_sec = config['data']['max_sec']
  223. t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
  224. t2s_model.load_state_dict(dict_s1["weight"])
  225. if is_half:
  226. t2s_model = t2s_model.half()
  227. t2s_model = t2s_model.to(device)
  228. t2s_model.eval()
  229. total = sum([param.nelement() for param in t2s_model.parameters()])
  230. print("Number of parameter: %.2fM" % (total / 1e6))
  231. def get_spepc(hps, filename):
  232. audio = load_audio(filename, int(hps.data.sampling_rate))
  233. audio = torch.FloatTensor(audio)
  234. audio_norm = audio
  235. audio_norm = audio_norm.unsqueeze(0)
  236. spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
  237. hps.data.win_length, center=False)
  238. return spec
  239. dict_language = {
  240. "中文": "zh",
  241. "英文": "en",
  242. "日文": "ja",
  243. "ZH": "zh",
  244. "EN": "en",
  245. "JA": "ja",
  246. "zh": "zh",
  247. "en": "en",
  248. "ja": "ja"
  249. }
  250. def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
  251. t0 = ttime()
  252. prompt_text = prompt_text.strip("\n")
  253. prompt_language, text = prompt_language, text.strip("\n")
  254. zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
  255. with torch.no_grad():
  256. wav16k, sr = librosa.load(ref_wav_path, sr=16000)
  257. wav16k = torch.from_numpy(wav16k)
  258. zero_wav_torch = torch.from_numpy(zero_wav)
  259. if (is_half == True):
  260. wav16k = wav16k.half().to(device)
  261. zero_wav_torch = zero_wav_torch.half().to(device)
  262. else:
  263. wav16k = wav16k.to(device)
  264. zero_wav_torch = zero_wav_torch.to(device)
  265. wav16k = torch.cat([wav16k, zero_wav_torch])
  266. ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
  267. codes = vq_model.extract_latent(ssl_content)
  268. prompt_semantic = codes[0, 0]
  269. t1 = ttime()
  270. prompt_language = dict_language[prompt_language]
  271. text_language = dict_language[text_language]
  272. phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
  273. phones1 = cleaned_text_to_sequence(phones1)
  274. texts = text.split("\n")
  275. audio_opt = []
  276. for text in texts:
  277. phones2, word2ph2, norm_text2 = clean_text(text, text_language)
  278. phones2 = cleaned_text_to_sequence(phones2)
  279. if (prompt_language == "zh"):
  280. bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
  281. else:
  282. bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
  283. device)
  284. if (text_language == "zh"):
  285. bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
  286. else:
  287. bert2 = torch.zeros((1024, len(phones2))).to(bert1)
  288. bert = torch.cat([bert1, bert2], 1)
  289. all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
  290. bert = bert.to(device).unsqueeze(0)
  291. all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
  292. prompt = prompt_semantic.unsqueeze(0).to(device)
  293. t2 = ttime()
  294. with torch.no_grad():
  295. # pred_semantic = t2s_model.model.infer(
  296. pred_semantic, idx = t2s_model.model.infer_panel(
  297. all_phoneme_ids,
  298. all_phoneme_len,
  299. prompt,
  300. bert,
  301. # prompt_phone_len=ph_offset,
  302. top_k=config['inference']['top_k'],
  303. early_stop_num=hz * max_sec)
  304. t3 = ttime()
  305. # print(pred_semantic.shape,idx)
  306. pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
  307. refer = get_spepc(hps, ref_wav_path) # .to(device)
  308. if (is_half == True):
  309. refer = refer.half().to(device)
  310. else:
  311. refer = refer.to(device)
  312. # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
  313. audio = \
  314. vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
  315. refer).detach().cpu().numpy()[
  316. 0, 0] ###试试重建不带上prompt部分
  317. audio_opt.append(audio)
  318. audio_opt.append(zero_wav)
  319. t4 = ttime()
  320. print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
  321. yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
  322. def handle_control(command):
  323. if command == "restart":
  324. os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
  325. elif command == "exit":
  326. os.kill(os.getpid(), signal.SIGTERM)
  327. exit(0)
  328. def handle_change(path, text, language):
  329. if is_empty(path, text, language):
  330. return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400)
  331. if path != "" or path is not None:
  332. default_refer.path = path
  333. if text != "" or text is not None:
  334. default_refer.text = text
  335. if language != "" or language is not None:
  336. default_refer.language = language
  337. print(f"[INFO] 当前默认参考音频路径: {default_refer.path}")
  338. print(f"[INFO] 当前默认参考音频文本: {default_refer.text}")
  339. print(f"[INFO] 当前默认参考音频语种: {default_refer.language}")
  340. print(f"[INFO] is_ready: {default_refer.is_ready()}")
  341. return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
  342. def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
  343. if (
  344. refer_wav_path == "" or refer_wav_path is None
  345. or prompt_text == "" or prompt_text is None
  346. or prompt_language == "" or prompt_language is None
  347. ):
  348. refer_wav_path, prompt_text, prompt_language = (
  349. default_refer.path,
  350. default_refer.text,
  351. default_refer.language,
  352. )
  353. if not default_refer.is_ready():
  354. return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
  355. with torch.no_grad():
  356. gen = get_tts_wav(
  357. refer_wav_path, prompt_text, prompt_language, text, text_language
  358. )
  359. sampling_rate, audio_data = next(gen)
  360. wav = BytesIO()
  361. sf.write(wav, audio_data, sampling_rate, format="wav")
  362. wav.seek(0)
  363. torch.cuda.empty_cache()
  364. if device == "mps":
  365. print('executed torch.mps.empty_cache()')
  366. torch.mps.empty_cache()
  367. return StreamingResponse(wav, media_type="audio/wav")
  368. app = FastAPI()
  369. @app.post("/control")
  370. async def control(request: Request):
  371. json_post_raw = await request.json()
  372. return handle_control(json_post_raw.get("command"))
  373. @app.get("/control")
  374. async def control(command: str = None):
  375. return handle_control(command)
  376. @app.post("/change_refer")
  377. async def change_refer(request: Request):
  378. json_post_raw = await request.json()
  379. return handle_change(
  380. json_post_raw.get("refer_wav_path"),
  381. json_post_raw.get("prompt_text"),
  382. json_post_raw.get("prompt_language")
  383. )
  384. @app.get("/change_refer")
  385. async def change_refer(
  386. refer_wav_path: str = None,
  387. prompt_text: str = None,
  388. prompt_language: str = None
  389. ):
  390. return handle_change(refer_wav_path, prompt_text, prompt_language)
  391. @app.post("/")
  392. async def tts_endpoint(request: Request):
  393. json_post_raw = await request.json()
  394. return handle(
  395. json_post_raw.get("refer_wav_path"),
  396. json_post_raw.get("prompt_text"),
  397. json_post_raw.get("prompt_language"),
  398. json_post_raw.get("text"),
  399. json_post_raw.get("text_language"),
  400. )
  401. @app.get("/")
  402. async def tts_endpoint(
  403. refer_wav_path: str = None,
  404. prompt_text: str = None,
  405. prompt_language: str = None,
  406. text: str = None,
  407. text_language: str = None,
  408. ):
  409. return handle(refer_wav_path, prompt_text, prompt_language, text, text_language)
  410. if __name__ == "__main__":
  411. uvicorn.run(app, host=host, port=port, workers=1)