1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- import os
- import gradio as gr
- from tts_webui.tortoise.gr_reload_button import gr_open_button_simple, gr_reload_button
- from tts_webui.utils.get_path_from_root import get_path_from_root
- from tts_webui.utils.manage_model_state import unload_model
def list_dir_models(abs_dir: str):
    """Return the names of the sub-directories directly inside ``abs_dir``.

    Only directories are reported (plain files such as ``.gitkeep`` are
    skipped), and the ``cache`` directory is excluded.

    Args:
        abs_dir: Directory to scan.

    Returns:
        A list of directory names, or an empty list when ``abs_dir`` does
        not exist or is not a directory (the error is printed).
    """
    try:
        # os.scandir raises for a missing path, whereas os.walk silently
        # yields nothing, which made next() raise StopIteration instead of
        # reaching the fallback below.
        with os.scandir(abs_dir) as entries:
            return [
                entry.name
                for entry in entries
                if entry.is_dir() and entry.name != "cache"
            ]
    except (FileNotFoundError, NotADirectoryError) as e:
        print(e)
        return []
def get_models(repos, abs_dir):
    """Combine the predefined ``repos`` with the models found in ``abs_dir``.

    Each directory reported by ``list_dir_models(abs_dir)`` is appended as a
    ``(name, path)`` tuple, where ``path`` is the name joined onto
    ``abs_dir``.

    Args:
        repos: Predefined entries placed first in the result.
        abs_dir: Directory whose sub-directories are offered as local models.

    Returns:
        ``repos`` followed by one ``(name, path)`` tuple per local model.
    """
    local_models = []
    for name in list_dir_models(abs_dir):
        local_models.append((name, os.path.join(abs_dir, name)))
    return repos + local_models
def model_select_ui(
    repos,
    prefix: str,
    Component: type[gr.Radio | gr.Dropdown] = gr.Radio,
):
    """Build a model selector together with reload and open-folder buttons.

    The choices are ``repos`` plus the model directories found under
    ``data/models/<prefix>`` (resolved through ``get_path_from_root``).

    Args:
        repos: Predefined choices listed before the local models. Each entry
            is indexed with ``[1]`` to obtain the default value, so entries
            are presumably ``(label, value)`` pairs -- confirm with callers.
        prefix: Sub-folder name under ``data/models``; also used as the
            prefix of the registered API endpoint names.
        Component: Gradio component class used to render the selector.

    Returns:
        The created selector component.
    """
    abs_dir = get_path_from_root("data", "models", prefix)
    models = get_models(repos, abs_dir)

    model = Component(
        choices=models,
        label="Model",
        # With no predefined repos and no local model directories the list
        # is empty; fall back to no selection instead of raising IndexError.
        value=models[0][1] if models else None,
    )

    # Re-scan the model directory on demand and refresh the choices.
    gr_reload_button().click(
        fn=lambda: Component(choices=get_models(repos, abs_dir)),
        outputs=[model],
        api_name=f"{prefix}_get_models",
    )
    gr_open_button_simple(abs_dir, api_name=f"{prefix}_open_model_dir")
    return model
def unload_model_button(prefix: str):
    """Create an "Unload Model" button wired to ``unload_model``.

    Args:
        prefix: Passed to ``unload_model`` as ``model_namespace`` and used
            as the prefix of the registered API endpoint name.

    Returns:
        The value returned by the button's ``click`` registration (not the
        button itself).
    """

    def _unload():
        return unload_model(model_namespace=prefix)

    button = gr.Button(value="Unload Model")
    return button.click(
        fn=_unload,
        api_name=f"{prefix}_unload_model",
    )
|