main.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import gradio as gr
  2. import os
  3. def extension__tts_generation_webui():
  4. model_download_ui()
  5. return {
  6. "package_name": "extension_model_downloader",
  7. "name": "Model Downloader",
  8. "version": "0.0.1",
  9. "requirements": "git+https://github.com/rsxdalv/extension_model_downloader@main",
  10. "description": "Model Downloader allows downloading models from the Huggingface model hub.",
  11. "extension_type": "interface",
  12. "extension_class": "tools",
  13. "author": "rsxdalv",
  14. "extension_author": "rsxdalv",
  15. "license": "MIT",
  16. "website": "https://github.com/rsxdalv/extension_model_downloader",
  17. "extension_website": "https://github.com/rsxdalv/extension_model_downloader",
  18. "extension_platform_version": "0.0.1",
  19. }
  20. _model_base_dir = os.path.join("data", "models")
  21. def download_pretrained_model(model_type: str, name: str, token: str):
  22. from huggingface_hub import snapshot_download
  23. local_dir = os.path.join(
  24. _model_base_dir,
  25. model_type,
  26. name.replace("/", "__").replace(":", "_").replace(".", "_"),
  27. )
  28. model_config_path = snapshot_download(
  29. name,
  30. repo_type="model",
  31. local_dir=local_dir,
  32. local_dir_use_symlinks=False,
  33. token=token,
  34. )
  35. print(model_config_path)
  36. return "ok"
  37. def model_download_ui():
  38. gr.Markdown(
  39. "Models can be found on the [HuggingFace model hub](https://huggingface.co/models?search=whisper)."
  40. )
  41. model_type = gr.Textbox(
  42. label="Model type, e.g. magnet, maha-tts, musicgen_audiogen, parler_tts, rvc, stable-audio, tortoise, vall-e-x, whisper, xtts",
  43. value="",
  44. )
  45. pretrained_name_text = gr.Textbox(
  46. label="HuggingFace repo name, e.g. openai/whisper-small",
  47. value="",
  48. )
  49. token_text = gr.Textbox(
  50. label="HuggingFace Token (Optional, but needed for some non-public models)",
  51. placeholder="hf_nFjKuKLJF...",
  52. value="",
  53. )
  54. download_btn = gr.Button("Download")
  55. download_btn.click(
  56. download_pretrained_model,
  57. inputs=[model_type, pretrained_name_text, token_text],
  58. outputs=[pretrained_name_text],
  59. api_name="model_download",
  60. )
  61. if __name__ == "__main__":
  62. if "demo" in locals():
  63. locals()["demo"].close()
  64. with gr.Blocks() as demo:
  65. model_download_ui()
  66. demo.queue().launch()