decorator_extensions.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import json
  2. import importlib
  3. import importlib.util
  4. from importlib.metadata import version
  5. import time
  6. from types import ModuleType
  7. from typing import Literal
  8. import gradio as gr
  9. from tts_webui.utils.pip_install import pip_install_wrapper, pip_uninstall_wrapper
  10. from tts_webui.utils.generic_error_tab_advanced import generic_error_tab_advanced
  11. def check_if_package_installed(package_name):
  12. spec = importlib.util.find_spec(package_name)
  13. return spec is not None
  14. # A list of disabled extensions and decorators
  15. disabled_extensions = ["decorator_disabled"]
  16. def get_extension_list_json():
  17. try:
  18. return json.load(open("extensions.json"))["decorators"]
  19. except Exception as e:
  20. print("\n! Failed to load extensions.json:", e)
  21. return []
  22. extension_list_json = get_extension_list_json()
  23. def extension_decorator_list_tab():
  24. with gr.Tab("Decorator Extensions List"):
  25. gr.Markdown("List of all extensions")
  26. table_string = """| Title | Description |\n| --- | --- |\n"""
  27. for x in extension_list_json:
  28. table_string += (
  29. # f"| {x['name']} (v{x['version']}) "
  30. f"| {x['name']} "
  31. + f"| {x['description']} (website: {x['website']}) (extension_website: {x['extension_website']}) |\n"
  32. )
  33. gr.Markdown(table_string)
  34. external_extension_list = [
  35. x for x in extension_list_json if "builtin" not in x["package_name"]
  36. ]
  37. with gr.Row():
  38. with gr.Column():
  39. gr.Markdown("Install/Uninstall Extensions")
  40. install_dropdown = gr.Dropdown(
  41. label="Select Extension to Install",
  42. choices=[x["package_name"] for x in external_extension_list],
  43. )
  44. install_button = gr.Button("Install extension")
  45. def install_extension(package_name):
  46. requirements = [
  47. x["requirements"]
  48. for x in external_extension_list
  49. if x["package_name"] == package_name
  50. ][0]
  51. yield from pip_install_wrapper(requirements, package_name)()
  52. install_button.click(
  53. fn=install_extension,
  54. inputs=[install_dropdown],
  55. outputs=[gr.HTML()],
  56. api_name="install_extension",
  57. )
  58. with gr.Column():
  59. gr.Markdown("Uninstall Extensions")
  60. uninstall_dropdown = gr.Dropdown(
  61. label="Select Extension to Uninstall",
  62. choices=[x["package_name"] for x in external_extension_list],
  63. )
  64. uninstall_button = gr.Button("Uninstall extension")
  65. def uninstall_extension(package_name):
  66. yield from pip_uninstall_wrapper(package_name, package_name)()
  67. uninstall_button.click(
  68. fn=uninstall_extension,
  69. inputs=[uninstall_dropdown],
  70. outputs=[gr.HTML()],
  71. api_name="uninstall_extension",
  72. )
  73. def _load_decorators(class_name: Literal["outer", "inner"]):
  74. """
  75. Loads all decorators from extensions.
  76. The decorators are loaded from the "main" module of the extension.
  77. Decorators must be functions prefixed with "decorator_".
  78. Generators are detected by the suffix "_generator".
  79. For example:
  80. def decorator_save_ogg(fn):
  81. def wrapper(*args, **kwargs):
  82. return fn(*args, **kwargs)
  83. return wrapper
  84. def decorator_save_ogg_generator(fn):
  85. def wrapper(*args, **kwargs):
  86. yield from fn(*args, **kwargs)
  87. return wrapper
  88. Args:
  89. class_name (str): "outer" or "inner"
  90. Returns:
  91. wrappers (list): List of decorators.
  92. gen_wrappers (list): List of decorators for generators.
  93. """
  94. wrappers = []
  95. gen_wrappers = []
  96. def _parse_module(module: ModuleType, name: str):
  97. if name.startswith("decorator_"):
  98. if name in disabled_extensions:
  99. print(f" Skipping disabled decorator extension {name}")
  100. return
  101. if name.endswith("_generator"):
  102. gen_wrappers.append(getattr(module, name))
  103. print(f" Decorator {name} loaded")
  104. return
  105. wrappers.append(getattr(module, name))
  106. print(f" Decorator {name} loaded")
  107. def _load(x: dict):
  108. if x["package_name"] in disabled_extensions:
  109. print(f"Skipping disabled decorator extension {x['name']}")
  110. return
  111. module = importlib.import_module(f"{x['package_name']}.main")
  112. for name in dir(module):
  113. _parse_module(module, name)
  114. filtered_extensions = [
  115. x
  116. for x in extension_list_json
  117. if x["extension_type"] == "decorator" and x["extension_class"] == class_name
  118. ]
  119. for x in filtered_extensions:
  120. print(f"Loading decorator extension {x['name']}")
  121. start_time = time.time()
  122. try:
  123. _load(x)
  124. except Exception as e:
  125. print(f"Failed to load decorator extension {x['name']}: {e}")
  126. finally:
  127. elapsed_time = time.time() - start_time
  128. print(f" Done in {elapsed_time:.2f} seconds. ({x['name']})\n")
  129. return wrappers, gen_wrappers
  130. OUTER_WRAPPERS, OUTER_WRAPPERS_GEN = _load_decorators("outer")
  131. INNER_WRAPPERS, INNER_WRAPPERS_GEN = _load_decorators("inner")
  132. def decorator_extension_outer(fn0):
  133. return _decorator_extension(OUTER_WRAPPERS, fn0)
  134. def decorator_extension_inner(fn0):
  135. return _decorator_extension(INNER_WRAPPERS, fn0)
  136. def decorator_extension_outer_generator(fn0):
  137. return _decorator_extension(OUTER_WRAPPERS_GEN, fn0)
  138. def decorator_extension_inner_generator(fn0):
  139. return _decorator_extension(INNER_WRAPPERS_GEN, fn0)
  140. def _decorator_extension(wrappers, fn0):
  141. wrappers.reverse()
  142. for wrapper in wrappers:
  143. fn0 = wrapper(fn0)
  144. return fn0
  145. if __name__ == "__main__":
  146. pass