123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- from tts_webui.utils.torch_clear_memory import torch_clear_memory
- class ModelState:
- def __init__(self):
- self._model = None
- self._model_name = None
- def set_model(self, model, model_name):
- self._model = model
- self._model_name = model_name
- def get_model(self):
- return self._model
- def is_model_loaded(self, model_name):
- return self._model is not None and self._model_name == model_name
- def get_model_name(self):
- return self._model_name
- model_states = {}
- def manage_model_state(model_namespace):
- """Decorator to manage the model state."""
- def decorator(func):
- def wrapper(model_name, *args, **kwargs):
- global model_states
- if model_namespace not in model_states:
- model_states[model_namespace] = ModelState()
- model_state = model_states[model_namespace]
- if not model_state.is_model_loaded(model_name):
- print(
- f"Model '{model_name}' in namespace '{model_namespace}' is not loaded or is different. Loading model..."
- )
- unload_model(model_namespace)
- model = func(model_name, *args, **kwargs)
- model_state.set_model(model, model_name)
- else:
- print(
- f"Using cached model '{model_name}' in namespace '{model_namespace}'."
- )
- return model_state.get_model()
- return wrapper
- return decorator
- def unload_model(model_namespace):
- if (
- model_namespace in model_states
- and model_states[model_namespace].get_model() is not None
- ):
- model_states[model_namespace].set_model(None, None)
- # del model_states[model_namespace]
- torch_clear_memory()
- print(f"Model in namespace '{model_namespace}' has been unloaded.")
- else:
- print(f"No model loaded in namespace '{model_namespace}'.")
- def unload_all_models():
- for model_namespace in list(model_states.keys()):
- unload_model(model_namespace)
- def list_loaded_models_as_markdown():
- lines = ["| Model Namespace | Model Name |", "|-----------------|------------|"]
- for namespace, state in model_states.items():
- model_name = state.get_model_name()
- if model_name:
- lines.append(f"| {namespace} | {model_name} |")
- else:
- lines.append(f"| {namespace} | Not Loaded |")
- return "\n".join(lines)
|