2
0

manage_model_state.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from tts_webui.utils.torch_clear_memory import torch_clear_memory
  class ModelState:
      """Holds at most one model object together with the name it was stored under."""

      def __init__(self):
          # Both slots start empty; set_model() always fills them as a pair.
          self._model = None
          self._model_name = None

      def set_model(self, model, model_name):
          """Store ``model`` and remember ``model_name`` as its identifier."""
          self._model, self._model_name = model, model_name

      def get_model(self):
          """Return the stored model object, or None if nothing is held."""
          return self._model

      def is_model_loaded(self, model_name):
          """Return True only when a model is held AND it was stored as ``model_name``."""
          if self._model is None:
              return False
          return self._model_name == model_name

      def get_model_name(self):
          """Return the name given to the last set_model() call (None initially)."""
          return self._model_name
  # Registry of per-namespace model holders: namespace -> ModelState.
  # Entries are created lazily by the manage_model_state wrapper.
  model_states: dict = {}
  def manage_model_state(model_namespace):
      """Build a decorator that caches one loaded model per ``model_namespace``.

      The decorated loader is called as ``func(model_name, *args, **kwargs)``.
      The wrapper keeps a ``ModelState`` for the namespace in ``model_states``:

      * if that state already holds a model stored under ``model_name``, the
        cached model is returned and ``func`` is not called;
      * otherwise the namespace is unloaded via ``unload_model`` first, then
        ``func`` is called and its return value is stored under ``model_name``.

      Only ``model_name`` identifies the cached model; ``*args``/``**kwargs``
      are NOT part of the cache key. A loader that returns ``None`` is never
      treated as loaded (``ModelState.is_model_loaded`` requires a non-None
      model), so such a loader runs again on every call.

      ``functools.wraps`` is applied so the wrapper keeps the loader's
      ``__name__``, ``__doc__`` and exposes it as ``__wrapped__``.
      """
      from functools import wraps

      def decorator(func):
          @wraps(func)
          def wrapper(model_name, *args, **kwargs):
              # The registry dict is only mutated, never rebound, so no
              # ``global`` statement is required here.
              if model_namespace not in model_states:
                  model_states[model_namespace] = ModelState()
              model_state = model_states[model_namespace]

              if not model_state.is_model_loaded(model_name):
                  print(
                      f"Model '{model_name}' in namespace '{model_namespace}' is not loaded or is different. Loading model..."
                  )
                  # Release whatever this namespace held before loading the new model.
                  unload_model(model_namespace)
                  model = func(model_name, *args, **kwargs)
                  model_state.set_model(model, model_name)
              else:
                  print(
                      f"Using cached model '{model_name}' in namespace '{model_namespace}'."
                  )
              return model_state.get_model()

          return wrapper

      return decorator
  def unload_model(model_namespace):
      """Drop the model held for ``model_namespace``, if there is one.

      The namespace's ModelState is reset to (None, None) but its registry
      entry is kept, so it still appears (as "Not Loaded") in
      list_loaded_models_as_markdown(). After clearing the reference,
      torch_clear_memory() is called (presumably to free cached accelerator
      memory; see that helper). A status line is printed in either case.
      """
      has_model = (
          model_namespace in model_states
          and model_states[model_namespace].get_model() is not None
      )
      if not has_model:
          print(f"No model loaded in namespace '{model_namespace}'.")
          return

      model_states[model_namespace].set_model(None, None)
      torch_clear_memory()
      print(f"Model in namespace '{model_namespace}' has been unloaded.")
  def unload_all_models():
      """Call unload_model() for every namespace currently in the registry."""
      # Iterate over a snapshot of the registry's keys rather than the live dict.
      for namespace in tuple(model_states):
          unload_model(namespace)
  def list_loaded_models_as_markdown():
      """Return a Markdown table of every registered namespace and its model name.

      Namespaces whose stored name is falsy (e.g. after unload_model()) are
      listed as "Not Loaded".
      """
      table = [
          "| Model Namespace | Model Name |",
          "|-----------------|------------|",
      ]
      table.extend(
          f"| {namespace} | {state.get_model_name() or 'Not Loaded'} |"
          for namespace, state in model_states.items()
      )
      return "\n".join(table)