# gen_tortoise.py
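"""Tortoise TTS generation helpers for the web UI.

Loads and caches the Tortoise ``TextToSpeech`` model, caches voice
conditioning data, and writes each generation to ``outputs/`` as a .wav
file together with a waveform .png and a .json metadata file.
"""
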
import os

import numpy as np
import gradio as gr
from scipy.io.wavfile import write as write_wav

from tts_webui.tortoise.save_json import save_json
from tts_webui.bark.split_text_functions import split_by_lines
from tts_webui.utils.create_base_filename import create_base_filename
from tts_webui.utils.date import get_date_string
from tts_webui.utils.save_waveform_plot import middleware_save_waveform_plot
from tts_webui.tortoise.TortoiseParameters import TortoiseParameters
from tts_webui.utils.get_path_from_root import get_path_from_root
from tts_webui.utils.torch_clear_memory import torch_clear_memory


SAMPLE_RATE = 24_000
OUTPUT_PATH = "outputs/"
MODEL = None
TORTOISE_VOICE_DIR = "voices-tortoise"
TORTOISE_VOICE_DIR_ABS = get_path_from_root("voices-tortoise")
TORTOISE_LOCAL_MODELS_DIR = get_path_from_root("data", "models", "tortoise")


class TortoiseOutputUpdate:
    def __init__(
        self,
        audio,
        bundle_name,
        params,
    ):
        self.audio = audio
        self.bundle_name = bundle_name
        self.params = params


def get_model_list():
    try:
        return ["Default"] + [
            x for x in os.listdir(TORTOISE_LOCAL_MODELS_DIR) if x != ".gitkeep"
        ]
    except FileNotFoundError as e:
        print(e)
        return ["Default"]


def get_full_model_dir(model_dir: str):
    return os.path.join(TORTOISE_LOCAL_MODELS_DIR, model_dir)


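# Reload the Tortoise model from either the bundled default weights or a
# folder under data/models/tortoise, optionally with a custom tokenizer.
# Returns a gr.Dropdown() for the caller's output component.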
def switch_model(
    model_dir: str,
    kv_cache=False,
    use_deepspeed=False,
    half=False,
    tokenizer=None,
    use_basic_cleaners=False,
):
    from tortoise.api import MODELS_DIR

    get_tts(
        models_dir=(
            MODELS_DIR if model_dir == "Default" else get_full_model_dir(model_dir)
        ),
        force_reload=True,
        kv_cache=kv_cache,
        use_deepspeed=use_deepspeed,
        half=half,
        tokenizer_path=tokenizer.name if tokenizer else None,
        tokenizer_basic=use_basic_cleaners,
    )
    return gr.Dropdown()


def get_voice_list():
    from tortoise.utils.audio import get_voices

    return ["random"] + list(get_voices(extra_voice_dirs=[TORTOISE_VOICE_DIR]))


def save_wav_tortoise(audio_array, filename):
    write_wav(filename, SAMPLE_RATE, audio_array)


def unload_tortoise_model():
    global MODEL
    if MODEL is not None:
        del MODEL
        torch_clear_memory()
        MODEL = None


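# Build the TextToSpeech instance on first use and cache it in MODEL;
# force_reload tears the cached model down and rebuilds it with new settings.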
def get_tts(
    models_dir=None,
    force_reload=False,
    kv_cache=False,
    use_deepspeed=False,
    half=False,
    device=None,
    tokenizer_path=None,
    tokenizer_basic=False,
):
    from tortoise.api import MODELS_DIR, TextToSpeech

    if models_dir is None:
        models_dir = MODELS_DIR

    global MODEL
    if MODEL is None or force_reload:
        print("Loading tortoise model: ", models_dir)
        print("Clearing memory...")
        unload_tortoise_model()
        print("Memory cleared")
        print("Loading model...")
        MODEL = TextToSpeech(
            models_dir=models_dir,
            kv_cache=kv_cache,
            use_deepspeed=use_deepspeed,
            half=half,
            device=device,
            tokenizer_vocab_file=tokenizer_path,
            tokenizer_basic=tokenizer_basic,
        )
        print("Model loaded")
    return MODEL


# Per-voice cache so repeated generations with the same voice do not reload
# the reference clips and conditioning latents.
last_voices = None
voice_samples = None
conditioning_latents = None


def get_voices_cached(voice):
    from tortoise.utils.audio import load_voices

    global last_voices, voice_samples, conditioning_latents
    if voice == last_voices:
        return voice_samples, conditioning_latents

    voices = voice.split("&") if "&" in voice else [voice]
    voice_samples, conditioning_latents = load_voices(
        voices, extra_voice_dirs=[TORTOISE_VOICE_DIR]
    )
    # Cache against the raw voice string so the equality check above can hit.
    last_voices = voice
    return voice_samples, conditioning_latents


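# Run a single Tortoise generation for one prompt and return one
# TortoiseOutputUpdate per requested candidate.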
def generate_tortoise(
    params: TortoiseParameters,
    text: str,
    candidates: int,
):
    os.makedirs(OUTPUT_PATH, exist_ok=True)

    voice_samples, conditioning_latents = get_voices_cached(params.voice)

    tts = get_tts()
    result, state = tts.tts_with_preset(
        text,
        return_deterministic_state=True,
        k=candidates,
        voice_samples=voice_samples,
        conditioning_latents=conditioning_latents,
        use_deterministic_seed=get_seed(params),
        # Every remaining parameter field is passed straight through to
        # tts_with_preset; the excluded keys are UI/bookkeeping fields.
        **{
            k: v
            for k, v in params.to_dict().items()
            if k not in ["text", "voice", "split_prompt", "seed", "model", "name"]
        },
    )
    seed, _, _, _ = state
    params.seed = seed  # type: ignore

    gen_list = result if isinstance(result, list) else [result]
    audio_arrays = [tensor_to_audio_array(x) for x in gen_list]

    return [
        _process_gen(candidates, audio_array, id, params)
        for id, audio_array in enumerate(audio_arrays)
    ]


def get_seed(params):
    return params.seed if params.seed != -1 else None


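# Persist one candidate: write the .wav, the waveform .png and the .json
# metadata, then wrap everything in a TortoiseOutputUpdate for the UI.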
def _process_gen(candidates, audio_array, id, params: TortoiseParameters):
    model = "tortoise"
    date = get_date_string()
    name = params.name or params.voice

    filename, filename_png, filename_json = get_filenames(
        create_base_filename_tortoise(name, id, model, date)
    )
    save_wav_tortoise(audio_array, filename)
    middleware_save_waveform_plot(audio_array, filename_png)

    metadata = {
        "_version": "0.0.1",
        "_type": model,
        "date": date,
        "candidates": candidates,
        "index": id if isinstance(id, int) else 0,
        **params.to_metadata(),
    }
    save_json(filename_json, metadata)

    folder_root = os.path.dirname(filename)

    return TortoiseOutputUpdate(
        audio=(SAMPLE_RATE, audio_array),
        bundle_name=folder_root,
        params=gr.JSON(value=metadata),  # broken because gradio returns only __type__
    )


def create_base_filename_tortoise(name, j, model, date):
    return create_base_filename(f"{name}__n{j}", OUTPUT_PATH, model, date)


def tensor_to_audio_array(gen):
    return gen.squeeze(0).cpu().t().numpy()


def get_filenames(base_filename):
    filename = f"{base_filename}.wav"
    filename_png = f"{base_filename}.png"
    filename_json = f"{base_filename}.json"
    return filename, filename_png, filename_json


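# Generator used by the UI: yields per-prompt results as they finish and,
# when the text was split into several prompts, finally yields the
# concatenated "long" audio for each candidate.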
def generate_tortoise_long(count: int, params: TortoiseParameters):
    print("Generating tortoise with params:")
    print(params)

    prompt_raw = params.text
    split_prompt = params.split_prompt

    prompts = split_by_lines(prompt_raw) if split_prompt else [prompt_raw]
    audio_pieces = [[] for _ in range(count)]

    for prompt in prompts:
        datas = generate_tortoise(
            params,
            text=prompt,
            candidates=count,
        )
        for data in datas:
            yield [data.audio, data.bundle_name, data.params]
        for i in range(count):
            audio_array = datas[i].audio[1]
            audio_pieces[i].append(audio_array)

    # if there is only one prompt, then we don't need to concatenate
    if len(prompts) == 1:
        # return [None, None, None]
        return {}

    for i in range(count):
        res = _process_gen(
            count, np.concatenate(audio_pieces[i]), id=f"_long_{str(i)}", params=params
        )
        yield [res.audio, res.bundle_name, res.params]

    # return [None, None, None]
    return {}