server.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. # ruff: noqa: E402
  2. # %%
  3. print("Starting server...\n")
  4. import tts_webui.utils.setup_or_recover as setup_or_recover
  5. setup_or_recover.setup_or_recover()
  6. import tts_webui.utils.dotenv_init as dotenv_init
  7. dotenv_init.init()
  8. import os
  9. import gradio as gr
  10. from tts_webui.utils.suppress_warnings import suppress_warnings
  11. suppress_warnings()
  12. from tts_webui.config.load_config import default_config
  13. from tts_webui.config.config import config
  14. from tts_webui.css.css import full_css
  15. from tts_webui.history_tab.collections_directories_atom import (
  16. collections_directories_atom,
  17. )
  18. from tts_webui.utils.generic_error_tab_advanced import generic_error_tab_advanced
  19. from tts_webui.extensions_loader.interface_extensions import (
  20. extension_list_tab,
  21. handle_extension_class,
  22. )
  23. from tts_webui.extensions_loader.decorator_extensions import (
  24. extension_decorator_list_tab,
  25. )
  26. def reload_config_and_restart_ui():
  27. os._exit(0)
  28. # print("Reloading config and restarting UI...")
  29. # config = load_config()
  30. # gradio_interface_options = config["gradio_interface_options"] if "gradio_interface_options" in config else {}
  31. # demo.close()
  32. # time.sleep(1)
  33. # demo.launch(**gradio_interface_options)
  34. gradio_interface_options = (
  35. config["gradio_interface_options"]
  36. if "gradio_interface_options" in config
  37. else default_config
  38. )
  39. import time
  40. import importlib
  41. def run_tab(module_name, function_name, name, requirements=None):
  42. print(f"Loading {name} tab...")
  43. start_time = time.time()
  44. try:
  45. module = importlib.import_module(module_name)
  46. func = getattr(module, function_name)
  47. func()
  48. except Exception as e:
  49. generic_error_tab_advanced(e, name=name, requirements=requirements)
  50. finally:
  51. elapsed_time = time.time() - start_time
  52. print(f" Done in {elapsed_time:.2f} seconds. ({name})\n")
  53. def load_tabs(list_of_tabs):
  54. for tab in list_of_tabs:
  55. module_name, function_name, name = tab[:3]
  56. requirements = tab[3] if len(tab) > 3 else None
  57. run_tab(module_name, function_name, name, requirements)
  58. def main_ui(theme_choice="Base"):
  59. themes = {
  60. "Base": gr.themes.Base,
  61. "Default": gr.themes.Default,
  62. "Monochrome": gr.themes.Monochrome,
  63. }
  64. theme: gr.themes.Base = themes[theme_choice](
  65. # primary_hue="blue",
  66. primary_hue="sky",
  67. secondary_hue="sky",
  68. neutral_hue="neutral",
  69. font=[
  70. gr.themes.GoogleFont("Inter"),
  71. "ui-sans-serif",
  72. "system-ui",
  73. "sans-serif",
  74. ],
  75. )
  76. theme.set(
  77. embed_radius="*radius_sm",
  78. block_label_radius="*radius_sm",
  79. block_label_right_radius="*radius_sm",
  80. block_radius="*radius_sm",
  81. block_title_radius="*radius_sm",
  82. container_radius="*radius_sm",
  83. checkbox_border_radius="*radius_sm",
  84. input_radius="*radius_sm",
  85. table_radius="*radius_sm",
  86. button_large_radius="*radius_sm",
  87. button_small_radius="*radius_sm",
  88. button_primary_background_fill_hover="*primary_300",
  89. button_primary_background_fill_hover_dark="*primary_600",
  90. button_secondary_background_fill_hover="*secondary_200",
  91. button_secondary_background_fill_hover_dark="*secondary_600",
  92. )
  93. with gr.Blocks(
  94. css=full_css,
  95. title="TTS Generation WebUI",
  96. analytics_enabled=False, # it broke too many times
  97. theme=theme,
  98. ) as blocks:
  99. gr.Markdown(
  100. """
  101. # TTS Generation WebUI (Legacy - Gradio) [React UI](http://localhost:3000) [Feedback / Bug reports](https://forms.gle/2L62owhBsGFzdFBC8) [Discord Server](https://discord.gg/V8BKTVRtJ9)
  102. ### _(Text To Speech, Audio & Music Generation, Conversion)_
  103. """
  104. )
  105. with gr.Tabs():
  106. all_tabs()
  107. return blocks
  108. def all_tabs():
  109. with gr.Tab("Text-to-Speech"), gr.Tabs():
  110. tts_tabs = [
  111. ("tts_webui.bark.bark_tab", "bark_tab", "Bark TTS"),
  112. (
  113. "tts_webui.bark.clone.tab_voice_clone",
  114. "tab_voice_clone",
  115. "Bark Voice Clone",
  116. "-r requirements_bark_hubert_quantizer.txt",
  117. ),
  118. (
  119. "tts_webui.tortoise.tortoise_tab",
  120. "tortoise_tab",
  121. "Tortoise TTS",
  122. ),
  123. (
  124. "tts_webui.seamlessM4T.seamless_tab",
  125. "seamless_tab",
  126. "SeamlessM4Tv2Model",
  127. ),
  128. (
  129. "tts_webui.vall_e_x.vall_e_x_tab",
  130. "valle_x_tab",
  131. "Valle-X",
  132. "-r requirements_vall_e.txt",
  133. ),
  134. ("tts_webui.mms.mms_tab", "mms_tab", "MMS"),
  135. (
  136. "tts_webui.maha_tts.maha_tts_tab",
  137. "maha_tts_tab",
  138. "MahaTTS",
  139. "-r requirements_maha_tts.txt",
  140. ),
  141. (
  142. "tts_webui.styletts2.styletts2_tab",
  143. "style_tts2_tab",
  144. "StyleTTS2",
  145. "-r requirements_styletts2.txt",
  146. ),
  147. ]
  148. load_tabs(tts_tabs)
  149. handle_extension_class("text-to-speech", config)
  150. with gr.Tab("Audio/Music Generation"), gr.Tabs():
  151. audio_music_generation_tabs = [
  152. (
  153. "tts_webui.stable_audio.stable_audio_tab",
  154. "stable_audio_tab",
  155. "Stable Audio",
  156. "-r requirements_stable_audio.txt",
  157. ),
  158. (
  159. "tts_webui.magnet.magnet_tab",
  160. "magnet_tab",
  161. "MAGNeT",
  162. "-r requirements_audiocraft.txt",
  163. ),
  164. (
  165. "tts_webui.musicgen.musicgen_tab",
  166. "musicgen_tab",
  167. "MusicGen",
  168. "-r requirements_audiocraft.txt",
  169. ),
  170. ]
  171. load_tabs(audio_music_generation_tabs)
  172. handle_extension_class("audio-music-generation", config)
  173. with gr.Tab("Audio Conversion"), gr.Tabs():
  174. audio_conversion_tabs = [
  175. (
  176. "tts_webui.rvc_tab.rvc_tab",
  177. "rvc_conversion_tab",
  178. "RVC",
  179. "-r requirements_rvc.txt",
  180. ),
  181. (
  182. "tts_webui.rvc_tab.uvr5_tab",
  183. "uvr5_tab",
  184. "UVR5",
  185. "-r requirements_rvc.txt",
  186. ),
  187. (
  188. "tts_webui.demucs.demucs_tab",
  189. "demucs_tab",
  190. "Demucs",
  191. "-r requirements_audiocraft.txt",
  192. ),
  193. ("tts_webui.vocos.vocos_tabs", "vocos_tabs", "Vocos"),
  194. ]
  195. load_tabs(audio_conversion_tabs)
  196. handle_extension_class("audio-conversion", config)
  197. with gr.Tab("Outputs"), gr.Tabs():
  198. from tts_webui.history_tab.main import history_tab
  199. collections_directories_atom.render()
  200. try:
  201. history_tab()
  202. history_tab(directory="favorites")
  203. history_tab(
  204. directory="outputs",
  205. show_collections=True,
  206. )
  207. except Exception as e:
  208. generic_error_tab_advanced(e, name="History", requirements=None)
  209. outputs_tabs = [
  210. # voices
  211. # ("tts_webui.history_tab.voices_tab", "voices_tab", "Voices"),
  212. ]
  213. load_tabs(outputs_tabs)
  214. handle_extension_class("outputs", config)
  215. with gr.Tab("Tools"), gr.Tabs():
  216. tools_tabs = []
  217. load_tabs(tools_tabs)
  218. handle_extension_class("tools", config)
  219. with gr.Tab("Settings"), gr.Tabs():
  220. from tts_webui.settings_tab_gradio import settings_tab_gradio
  221. settings_tab_gradio(reload_config_and_restart_ui, gradio_interface_options)
  222. settings_tabs = [
  223. # (
  224. # "tts_webui.bark.settings_tab_bark",
  225. # "settings_tab_bark",
  226. # "Settings (Bark)",
  227. # ),
  228. (
  229. "tts_webui.utils.model_location_settings_tab",
  230. "model_location_settings_tab",
  231. "Model Location Settings",
  232. ),
  233. ("tts_webui.utils.gpu_info_tab", "gpu_info_tab", "GPU Info"),
  234. ("tts_webui.utils.pip_list_tab", "pip_list_tab", "Installed Packages"),
  235. ]
  236. load_tabs(settings_tabs)
  237. extension_list_tab()
  238. extension_decorator_list_tab()
  239. handle_extension_class("settings", config)
  240. def start_gradio_server():
  241. def print_pretty_options(options):
  242. print(" Gradio interface options:")
  243. max_key_length = max(len(key) for key in options.keys())
  244. for key, value in options.items():
  245. if key == "auth" and value is not None:
  246. print(f" {key}:{' ' * (max_key_length - len(key))} {value[0]}:******")
  247. else:
  248. print(f" {key}:{' ' * (max_key_length - len(key))} {value}")
  249. # detect if --share is passed
  250. if "--share" in os.sys.argv:
  251. print("Gradio share mode enabled")
  252. gradio_interface_options["share"] = True
  253. if "--docker" in os.sys.argv:
  254. gradio_interface_options["server_name"] = "0.0.0.0"
  255. print("Info: Docker mode: opening gradio server on all interfaces")
  256. print("Starting Gradio server...")
  257. if "enable_queue" in gradio_interface_options:
  258. del gradio_interface_options["enable_queue"]
  259. if gradio_interface_options["auth"] is not None:
  260. # split username:password into (username, password)
  261. gradio_interface_options["auth"] = tuple(
  262. gradio_interface_options["auth"].split(":")
  263. )
  264. print("Gradio server authentication enabled")
  265. # delete show_tips option
  266. if "show_tips" in gradio_interface_options:
  267. del gradio_interface_options["show_tips"]
  268. # TypeError: Blocks.launch() got an unexpected keyword argument 'file_directories'
  269. if "file_directories" in gradio_interface_options:
  270. del gradio_interface_options["file_directories"]
  271. print_pretty_options(gradio_interface_options)
  272. demo = main_ui()
  273. print("\n\n")
  274. if gradio_interface_options["server_name"] == "0.0.0.0":
  275. print("Notice: Server is open to the internet")
  276. print(
  277. f"Gradio server will be available on http://localhost:{gradio_interface_options['server_port']}"
  278. )
  279. # concurrency_count=gradio_interface_options.get("concurrency_count", 5),
  280. demo.queue().launch(**gradio_interface_options, allowed_paths=["."])
  281. def server_hypervisor():
  282. import subprocess
  283. import signal
  284. import sys
  285. postgres_dir = os.path.join("data", "postgres")
  286. def stop_postgres(postgres_process):
  287. try:
  288. subprocess.check_call(f"pg_ctl stop -D {postgres_dir} -m fast", shell=True)
  289. print("PostgreSQL stopped gracefully.")
  290. except Exception as e:
  291. print(f"Error stopping PostgreSQL: {e}")
  292. postgres_process.terminate()
  293. def signal_handler(signal, frame, postgres_process):
  294. print("Shutting down...")
  295. stop_postgres(postgres_process)
  296. sys.exit(0)
  297. print("Starting React UI...")
  298. subprocess.Popen(
  299. "npm start --prefix react-ui",
  300. env={
  301. **os.environ,
  302. "GRADIO_BACKEND_AUTOMATIC": f"http://127.0.0.1:{gradio_interface_options['server_port']}",
  303. },
  304. shell=True,
  305. )
  306. if "--docker" in os.sys.argv:
  307. print("Info: Docker mode: skipping Postgres")
  308. return
  309. print("Starting Postgres...")
  310. postgres_process = subprocess.Popen(f"postgres -D {postgres_dir} -p 7773", shell=True)
  311. try:
  312. signal.signal(
  313. signal.SIGINT,
  314. lambda sig, frame: signal_handler(sig, frame, postgres_process),
  315. ) # Ctrl+C
  316. signal.signal(
  317. signal.SIGTERM,
  318. lambda sig, frame: signal_handler(sig, frame, postgres_process),
  319. ) # Termination signals
  320. if os.name != "nt":
  321. signal.signal(
  322. signal.SIGHUP,
  323. lambda sig, frame: signal_handler(sig, frame, postgres_process),
  324. ) # Terminal closure
  325. signal.signal(
  326. signal.SIGQUIT,
  327. lambda sig, frame: signal_handler(sig, frame, postgres_process),
  328. ) # Quit
  329. except (ValueError, OSError) as e:
  330. print(f"Failed to set signal handlers: {e}")
  331. if __name__ == "__main__":
  332. server_hypervisor()
  333. import webbrowser
  334. if gradio_interface_options["inbrowser"]:
  335. webbrowser.open("http://localhost:3000")
  336. start_gradio_server()