generate_and_save_metadata.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from tts_webui.bark.FullGeneration import FullGeneration
  2. from tts_webui.bark.BarkParams import BarkParams
  3. from bark.generation import models
  4. from tts_webui.history_tab.get_hash_memoized import get_hash_memoized
  5. def _is_big_model(model):
  6. return model.config.n_embd > 768
  7. def generate_bark_metadata(
  8. date: str,
  9. full_generation: FullGeneration,
  10. params: BarkParams,
  11. ):
  12. history_prompt = params["history_prompt"]
  13. metadata = {
  14. "_version": "0.0.3",
  15. "_hash_version": "0.0.2",
  16. "_type": params["_type"],
  17. "text": params["text"],
  18. "text_temp": params["text_temp"],
  19. "seed": params["seed"],
  20. "max_length": params["max_length"],
  21. "waveform_temp": params["waveform_temp"],
  22. "is_big_semantic_model": _is_big_model(models["text"]["model"]),
  23. "is_big_coarse_model": _is_big_model(models["coarse"]),
  24. "is_big_fine_model": _is_big_model(models["fine"]),
  25. "date": date,
  26. "hash": get_hash_memoized(full_generation),
  27. "history_prompt": history_prompt if isinstance(history_prompt, str) else None,
  28. "history_prompt_npz": (
  29. history_prompt if isinstance(history_prompt, str) else None
  30. ),
  31. "history_hash": get_hash_memoized(history_prompt),
  32. }
  33. return metadata