# bark_tab.py

import os
import shutil
import numpy as np
import gradio as gr
from tts_webui.config.config import config
# from tts_webui.bark.clone.tab_voice_clone import tab_voice_clone
from tts_webui.history_tab.voices_tab import voices_tab
from tts_webui.bark.settings_tab_bark import settings_tab_bark
from tts_webui.bark.get_speaker_gender import get_speaker_gender
from tts_webui.bark.npz_tools import get_npz_files, save_npz
from tts_webui.bark.split_text_functions import split_by_length_simple, split_by_lines
from tts_webui.bark.generation_settings import (
    PromptSplitSettings,
    LongPromptHistorySettings,
)
from tts_webui.history_tab.save_to_favorites import save_to_favorites
from tts_webui.utils.save_json_result import save_json_result
from tts_webui.utils.get_dict_props import get_dict_props
from tts_webui.utils.randomize_seed import randomize_seed_ui
from tts_webui.decorators.gradio_dict_decorator import dictionarize
from tts_webui.decorators.decorator_apply_torch_seed import (
    decorator_apply_torch_seed_generator,
)
from tts_webui.decorators.decorator_log_generation import (
    decorator_log_generation_generator,
)
from tts_webui.decorators.decorator_save_wav import (
    decorator_save_wav_generator,
)
from tts_webui.decorators.decorator_add_base_filename import (
    decorator_add_base_filename_generator,
    format_date_for_file,
)
from tts_webui.decorators.decorator_add_date import (
    decorator_add_date_generator,
)
from tts_webui.decorators.decorator_add_model_type import (
    decorator_add_model_type_generator,
)
from tts_webui.decorators.log_function_time import log_generator_time
from tts_webui.extensions_loader.decorator_extensions import (
    decorator_extension_inner_generator,
    decorator_extension_outer_generator,
)
from tts_webui.utils.outputs.path import get_relative_output_path_ext

# from bark import SAMPLE_RATE
SAMPLE_RATE = 24_000

# from bark.generation import SUPPORTED_LANGS
SUPPORTED_LANGS = [
    ("English", "en"),
    ("German", "de"),
    ("Spanish", "es"),
    ("French", "fr"),
    ("Hindi", "hi"),
    ("Italian", "it"),
    ("Japanese", "ja"),
    ("Korean", "ko"),
    ("Polish", "pl"),
    ("Portuguese", "pt"),
    ("Russian", "ru"),
    ("Turkish", "tr"),
    ("Chinese", "zh"),
]
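
# Generator decorator: for every result_dict the wrapped generator yields,
# write Bark metadata to JSON and save the full generation as an .npz voice
# file, then attach "metadata" and "npz_path" to the result.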
def _decorator_bark_save_metadata_generator(fn):
    def _save_metadata_and_npz(kwargs, result_dict):
        from tts_webui.bark.generate_and_save_metadata import generate_bark_metadata

        metadata = generate_bark_metadata(
            date=format_date_for_file(result_dict["date"]),
            full_generation=result_dict["full_generation"],
            params=kwargs,
        )
        save_json_result(result_dict, metadata)
        npz_path = get_relative_output_path_ext(result_dict, ".npz")
        save_npz(
            filename=npz_path,
            full_generation=result_dict["full_generation"],
            metadata=metadata,
        )
        result_dict["metadata"] = metadata
        result_dict["npz_path"] = npz_path
        return result_dict

    def wrapper(*args, **kwargs):
        for result_dict in fn(*args, **kwargs):
            if result_dict is None:
                continue
            yield _save_metadata_and_npz(kwargs, result_dict)

    return wrapper
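
# Split the prompt according to the selected long-prompt mode:
# no splitting, split by lines, or split by length.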
def _bark_get_prompts(text, long_prompt_radio):
    if text is None or text == "":
        raise ValueError("Prompt is empty")
    if long_prompt_radio == PromptSplitSettings.NONE:
        prompts = [text]
    else:
        prompts = (
            split_by_lines(text)
            if long_prompt_radio == PromptSplitSettings.LINES
            else split_by_length_simple(text)
        )
    return prompts
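
# Choose the history prompt for the next chunk of a long generation:
# continue from the last chunk, keep the original voice, or use none.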
def _get_long_gen_history_prompt(
    long_prompt_history_radio,
    last_generation,
    original_history_prompt,
):
    switcher = {
        LongPromptHistorySettings.CONTINUE: last_generation,
        LongPromptHistorySettings.CONSTANT: original_history_prompt,
        LongPromptHistorySettings.EMPTY: None,
    }
    return switcher.get(long_prompt_history_radio, None)
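
# Long-form generation. The decorator stack (innermost listed last) handles
# timing, extension hooks, logging, date/filename stamping, WAV and metadata
# saving, and torch seeding; the generator yields one audio chunk per prompt
# piece and, for multi-piece prompts, a final concatenated result marked
# "long": True.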
@decorator_add_model_type_generator("bark")
@decorator_extension_outer_generator
@decorator_apply_torch_seed_generator
@_decorator_bark_save_metadata_generator
@decorator_save_wav_generator
@decorator_add_base_filename_generator
@decorator_add_date_generator
@decorator_log_generation_generator
@decorator_extension_inner_generator
@log_generator_time
def bark_generate_long(
    text,
    history_prompt,
    long_prompt_radio,
    long_prompt_history_radio,
    **kwargs,
):
    from tts_webui.bark.extended_generate import custom_generate_audio
    from tts_webui.bark.BarkModelManager import bark_model_manager

    pieces = []
    original_history_prompt = history_prompt
    last_generation = history_prompt
    for prompt_piece in _bark_get_prompts(text, long_prompt_radio):
        history_prompt = _get_long_gen_history_prompt(
            long_prompt_history_radio,
            last_generation,
            original_history_prompt,
        )
        if not bark_model_manager.models_loaded:
            bark_model_manager.reload_models(config)
        full_generation, audio_array = custom_generate_audio(
            text=prompt_piece,
            history_prompt=history_prompt,
            **get_dict_props(
                kwargs,
                [
                    "burn_in_prompt",
                    "history_prompt_semantic",
                    "text_temp",
                    "waveform_temp",
                    "max_length",
                ],
            ),
            output_full=True,
        )
        last_generation = full_generation
        pieces += [audio_array]
        yield {
            "audio_out": (SAMPLE_RATE, audio_array),
            "full_generation": full_generation,
        }
    if len(pieces) == 1:
        return
    yield {
        "audio_out": (SAMPLE_RATE, np.concatenate(pieces)),
        "full_generation": full_generation,
        "long": True,
    }
    return

def unload_models():
    from tts_webui.bark.BarkModelManager import bark_model_manager

    bark_model_manager.unload_models()
    return gr.Button(value="Unloaded")


def bark_tab():
    with gr.Tab(label="Bark", id="generation_bark"):
        with gr.Tabs():
            with gr.Tab("Generation"):
                bark_ui()
            # tab_voice_clone()
            voices_tab()
            settings_tab_bark()
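
# Dropdown of saved .npz voices with small save / refresh / clear buttons;
# "save" copies the selected file into the voices/ directory.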
def _npz_dropdown_ui(label):
    npz_dropdown = gr.Dropdown(
        label=label,
        choices=get_npz_files() + [""],  # type: ignore
        type="value",
        value=None,
        allow_custom_value=True,
        show_label=True,
    )
    btn_style = {
        "size": "sm",
        "elem_classes": "btn-sm material-symbols-outlined",
    }
    gr.Button("save", **btn_style).click(  # type: ignore
        fn=lambda x: [
            shutil.copy(x, os.path.join("voices", os.path.basename(x))),
        ],
        inputs=[npz_dropdown],
    )
    gr.Button("refresh", **btn_style).click(  # type: ignore
        fn=lambda: gr.Dropdown(choices=get_npz_files() + [""]),  # type: ignore
        outputs=npz_dropdown,
        api_name=f"reload_old_generation_dropdown{'' if label == 'Audio Voice' else '_semantic'}",
    )
    gr.Button("clear", **btn_style).click(  # type: ignore
        fn=lambda: gr.Dropdown(value=None),
        outputs=npz_dropdown,
    )
    return npz_dropdown
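
# Radio-button picker for the built-in Bark speakers; any change rebuilds the
# history prompt voice string and refreshes the "Chosen voice / Gender" label.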
def _voice_select_ui(history_prompt):
    with gr.Row():
        use_v2 = gr.Checkbox(label="Use V2", value=False)
        choice_string = gr.Markdown(
            "Chosen voice: en_speaker_0, Gender: Unknown",
        )
    language = gr.Radio(
        [lang[0] for lang in SUPPORTED_LANGS],
        type="index",
        show_label=False,
        value="English",
    )
    speaker_id = gr.Radio(
        ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
        type="value",
        label="Speaker ID",
        value="0",
    )
    voice_inputs = [language, speaker_id, use_v2]

    def create_voice_string_lazy(language, speaker_id, use_v2):
        from tts_webui.bark.create_voice_string import create_voice_string

        return create_voice_string(language, speaker_id, use_v2)

    for i in voice_inputs:
        i.change(
            fn=create_voice_string_lazy,
            inputs=voice_inputs,
            outputs=[history_prompt],
        ).then(
            fn=lambda x: f"Chosen voice: {x}, Gender: {get_speaker_gender(x)}",
            inputs=history_prompt,
            outputs=[choice_string],
        )
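
# Main generation UI: voice pickers, long-prompt and temperature settings,
# seed control, prompt textboxes, audio output, and a Generate button that
# routes all inputs through bark_generate_long via dictionarize.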
def bark_ui():
    with gr.Row():
        unload_models_button = gr.Button("Unload models")
        unload_models_button.click(
            fn=unload_models,
            outputs=[unload_models_button],
        )

    with gr.Row():
        history_prompt_semantic = _npz_dropdown_ui("Semantic Voice (Optional)")
    with gr.Row():
        history_prompt = _npz_dropdown_ui("Audio Voice")
        with gr.Column():
            _voice_select_ui(history_prompt)

    with gr.Row():
        with gr.Column():
            long_prompt_radio = gr.Radio(
                PromptSplitSettings.choices,  # type: ignore
                type="value",
                label="Prompt type",
                value=PromptSplitSettings.NONE,
                show_label=False,
            )
            long_prompt_history_radio = gr.Radio(
                LongPromptHistorySettings.choices,  # type: ignore
                type="value",
                label="For each subsequent generation:",
                value=LongPromptHistorySettings.CONTINUE,
            )
            max_length = gr.Slider(
                label="Max length",
                value=15,
                minimum=0.1,
                maximum=18,
                step=0.1,
            )
        with gr.Column():
            # TODO: Add gradient temperature options (requires model changes)
            text_temp = gr.Slider(
                label="Text temperature",
                value=0.7,
                minimum=0.0,
                maximum=1.2,
                step=0.05,
            )
            waveform_temp = gr.Slider(
                label="Waveform temperature",
                value=0.7,
                minimum=0.0,
                maximum=1.2,
                step=0.05,
            )
        with gr.Column():
            seed, randomize_seed_callback = randomize_seed_ui()

    burn_in_prompt = gr.Textbox(
        label="Burn In Prompt (Optional)", lines=3, placeholder="Enter text here..."
    )
    text = gr.Textbox(label="Prompt", lines=3, placeholder="Enter text here...")

    with gr.Column():
        audio = gr.Audio(type="filepath", label="Generated audio")
        with gr.Row():
            save_button = gr.Button("Save", size="sm")
            continue_button = gr.Button("Use as history", size="sm")
            continue_semantic_button = gr.Button("Use as semantic history", size="sm")

    full_generation = gr.Textbox(visible=False)
    metadata = gr.JSON(visible=False)
    folder_root = gr.Textbox(visible=False)

    save_button.click(
        fn=save_to_favorites,
        inputs=[folder_root],
        outputs=[save_button],
        api_name="bark_favorite",
    )
    continue_button.click(
        fn=lambda x: x,
        inputs=[full_generation],
        outputs=[history_prompt],
    )
    continue_semantic_button.click(
        fn=lambda x: x,
        inputs=[full_generation],
        outputs=[history_prompt_semantic],
    )

    # fix the bug where selecting No history does not work with burn in prompt
    input_dict = {
        seed: "seed",
        text: "text",
        burn_in_prompt: "burn_in_prompt",
        text_temp: "text_temp",
        waveform_temp: "waveform_temp",
        max_length: "max_length",
        history_prompt: "history_prompt",
        history_prompt_semantic: "history_prompt_semantic",
        long_prompt_radio: "long_prompt_radio",
        long_prompt_history_radio: "long_prompt_history_radio",
    }
    output_dict = {
        "audio_out": audio,
        "npz_path": full_generation,
        "metadata": metadata,
        "folder_root": folder_root,
    }

    gr.Button("Generate", variant="primary").click(
        **randomize_seed_callback,
    ).then(
        **dictionarize(
            fn=bark_generate_long,
            inputs=input_dict,
            outputs=output_dict,
        ),
        api_name="bark",
    )
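
# Allow running this tab standalone during development (serves on port 7770).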
if __name__ == "__main__":
    if "demo" in locals():
        locals()["demo"].close()
    with gr.Blocks() as demo:
        bark_tab()
    demo.launch(
        server_port=7770,
    )