main.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. from typing import Optional
  2. import gradio as gr
  3. def extension__tts_generation_webui():
  4. scan_cache_ui()
  5. return {
  6. "package_name": "extension_huggingface_cache_manager",
  7. "name": "Huggingface Cache Manager",
  8. "version": "0.0.1",
  9. "requirements": "git+https://github.com/rsxdalv/extension_huggingface_cache_manager@main",
  10. "description": "Huggingface Cache Manager allows managing the Huggingface cache.",
  11. "extension_type": "interface",
  12. "extension_class": "tools",
  13. "author": "rsxdalv",
  14. "extension_author": "rsxdalv",
  15. "license": "MIT",
  16. "website": "https://github.com/rsxdalv/extension_huggingface_cache_manager",
  17. "extension_website": "https://github.com/rsxdalv/extension_huggingface_cache_manager",
  18. "extension_platform_version": "0.0.1",
  19. }
  20. from huggingface_hub import scan_cache_dir, HFCacheInfo
  21. from extensions.builtin.extension_huggingface_cache_manager.scan_cache import (
  22. get_headers_quiet,
  23. get_rows_quiet,
  24. get_headers_verbose,
  25. get_rows_verbose,
  26. get_headers_json,
  27. get_rows_json,
  28. render_as_markdown,
  29. )
# Most recent result of ``scan_cache_dir()``, shared by the handlers in this
# module. ``scan_cache`` and ``scan_cache_json`` overwrite it on every call;
# the delete and refresh handlers scan lazily when they find it ``None``.
hf_cache_info: Optional[HFCacheInfo] = None
  31. def scan_cache():
  32. global hf_cache_info
  33. hf_cache_info = scan_cache_dir()
  34. table = render_as_markdown(
  35. # get_rows_quiet(hf_cache_info), get_headers_quiet()
  36. get_rows_verbose(hf_cache_info), get_headers_verbose()
  37. )
  38. return table
  39. def scan_cache_json():
  40. global hf_cache_info
  41. hf_cache_info = scan_cache_dir()
  42. import json
  43. headers = get_headers_json()
  44. data = {
  45. "headers": headers,
  46. "rows": [
  47. {
  48. header: row[i]
  49. for i, header in enumerate(headers)
  50. }
  51. for row in get_rows_json(hf_cache_info)
  52. ]
  53. }
  54. return json.dumps(data)
  55. def delete_revisions(revision_id):
  56. global hf_cache_info
  57. if hf_cache_info is None:
  58. hf_cache_info = scan_cache_dir()
  59. strategy = hf_cache_info.delete_revisions(revision_id)
  60. strategy.execute()
  61. def scan_cache_ui():
  62. gr.Markdown("Scan the Huggingface cache directory and print the results.")
  63. scan_cache_button = gr.Button("Scan cache", variant="primary")
  64. cache_table = gr.Markdown("Press Scan cache to load the list")
  65. scan_cache_button.click(
  66. fn=scan_cache,
  67. outputs=[cache_table],
  68. api_name="scan_huggingface_cache",
  69. )
  70. scan_cache_json_api = gr.JSON(visible=False)
  71. scan_cache_button_api = gr.Button("API_SCAN_CACHE", visible=False)
  72. scan_cache_button_api.click(
  73. fn=scan_cache_json,
  74. outputs=[scan_cache_json_api],
  75. api_name="scan_huggingface_cache_api",
  76. )
  77. gr.Markdown("Delete revisions")
  78. delete_revision_id = gr.Dropdown(
  79. label="Revision ID",
  80. choices=[""],
  81. value="",
  82. show_label=True,
  83. interactive=True,
  84. )
  85. refresh_revision_id_button = gr.Button("Refresh", variant="secondary")
  86. def refresh_revision_id_button_fn():
  87. global hf_cache_info
  88. if hf_cache_info is None:
  89. hf_cache_info = scan_cache_dir()
  90. revision_ids = [
  91. revision.commit_hash
  92. for repo in hf_cache_info.repos
  93. for revision in repo.revisions
  94. ]
  95. return gr.Dropdown(choices=revision_ids)
  96. refresh_revision_id_button.click(
  97. fn=refresh_revision_id_button_fn,
  98. outputs=[delete_revision_id],
  99. api_name="refresh_huggingface_cache_revisions",
  100. )
  101. delete_button = gr.Button("Delete", variant="stop")
  102. delete_button.click(
  103. fn=delete_revisions,
  104. # inputs=[cache_table],
  105. inputs=[delete_revision_id],
  106. api_name="delete_huggingface_cache_revisions",
  107. )
  108. if __name__ == "__main__":
  109. if "demo" in locals():
  110. demo.close()
  111. with gr.Blocks() as demo:
  112. with gr.Tab("Scan Cache"):
  113. scan_cache_ui()
  114. demo.queue().launch(
  115. server_port=7770,
  116. )