list_dir_models.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import os
  2. import gradio as gr
  3. from tts_webui.tortoise.gr_reload_button import gr_open_button_simple, gr_reload_button
  4. from tts_webui.utils.get_path_from_root import get_path_from_root
  5. from tts_webui.utils.manage_model_state import unload_model
  6. def list_dir_models(abs_dir: str):
  7. try:
  8. # return [x for x in os.listdir(abs_dir) if x not in [".gitkeep", "cache"]]
  9. return [x for x in next(os.walk(abs_dir))[1] if x not in ["cache"]]
  10. except FileNotFoundError as e:
  11. print(e)
  12. return []
  13. def get_models(repos, abs_dir):
  14. return repos + [(x, os.path.join(abs_dir, x)) for x in list_dir_models(abs_dir)]
  15. def model_select_ui(
  16. repos,
  17. prefix: str,
  18. Component: type[gr.Radio | gr.Dropdown] = gr.Radio,
  19. ):
  20. abs_dir = get_path_from_root("data", "models", prefix)
  21. models = get_models(repos, abs_dir)
  22. model = Component(
  23. choices=models,
  24. label="Model",
  25. value=models[0][1],
  26. )
  27. gr_reload_button().click(
  28. fn=lambda: Component(choices=get_models(repos, abs_dir)),
  29. outputs=[model],
  30. api_name=f"{prefix}_get_models",
  31. )
  32. gr_open_button_simple(abs_dir, api_name=f"{prefix}_open_model_dir")
  33. return model
  34. def unload_model_button(prefix: str):
  35. return gr.Button(value="Unload Model").click(
  36. fn=lambda: unload_model(model_namespace=prefix),
  37. api_name=f"{prefix}_unload_model",
  38. )