# npz_tools.py
  1. import glob
  2. from typing import Any
  3. import numpy as np
  4. from tts_webui.bark.FullGeneration import FullGeneration
  5. import json
  6. import torch
  7. def compress_history(full_generation: FullGeneration):
  8. return {
  9. "semantic_prompt": full_generation["semantic_prompt"].astype(np.int16),
  10. "coarse_prompt": full_generation["coarse_prompt"].astype(np.int16),
  11. "fine_prompt": full_generation["fine_prompt"].astype(np.int16),
  12. }
  13. def pack_metadata(metadata: dict[str, Any]):
  14. # return list(json.dumps(metadata))
  15. def default(o):
  16. if isinstance(o, np.ndarray):
  17. return o.tolist()
  18. return o.__dict__
  19. return np.array(json.dumps(metadata, default=default))
  20. def save_npz(filename: str, full_generation: FullGeneration, metadata: dict[str, Any]):
  21. np.savez(
  22. filename,
  23. **{
  24. **compress_history(full_generation),
  25. "metadata": pack_metadata(metadata),
  26. },
  27. )
  28. def save_npz_musicgen(filename: str, tokens: torch.Tensor, metadata: dict[str, Any]):
  29. np.savez(
  30. filename,
  31. **{
  32. "tokens": tokens.cpu().numpy(),
  33. "metadata": pack_metadata(metadata),
  34. },
  35. )
  36. def load_npz(filename):
  37. def unpack_metadata(metadata: np.ndarray):
  38. def join_list(x: list | np.ndarray):
  39. if isinstance(x, np.ndarray):
  40. x = x.tolist()
  41. return "".join(x)
  42. return json.loads(join_list(metadata))
  43. with np.load(filename, allow_pickle=True) as data:
  44. result = {key: data[key] for key in data}
  45. if "metadata" in result:
  46. result["metadata"] = unpack_metadata(result["metadata"])
  47. return result
  48. def get_npz_files():
  49. return (
  50. glob.glob("voices/*.npz")
  51. + glob.glob("favorites/*/*.npz")
  52. + glob.glob("outputs/*/*.npz")
  53. )
  54. if __name__ == "__main__":
  55. in_npz = load_npz("./temp/ogg-vs-npz/audio__bark__None__2023-05-29_10-12-46.npz")
  56. metadata_in = {
  57. "_version": "0.0.1",
  58. "_hash_version": "0.0.2",
  59. "_type": "bark",
  60. "is_big_semantic_model": True,
  61. "is_big_coarse_model": False,
  62. "is_big_fine_model": False,
  63. "prompt": "test",
  64. "language": None,
  65. "speaker_id": None,
  66. "hash": "98b14851692f09df5e89c68f0a8e2013",
  67. "history_prompt": "continued_generation",
  68. "history_prompt_npz": None,
  69. "history_hash": "98b14851692f09df5e89c68f0a8e2013",
  70. "text_temp": 0.7,
  71. "waveform_temp": 0.7,
  72. "date": "2023-06-07_16-56-09",
  73. "seed": "2039063546",
  74. }
  75. save_npz(
  76. "./npz_reencode_test_new_list.npz",
  77. {
  78. "semantic_prompt": in_npz["semantic_prompt"],
  79. "coarse_prompt": in_npz["coarse_prompt"],
  80. "fine_prompt": in_npz["fine_prompt"],
  81. },
  82. metadata_in,
  83. )
  84. out_npz = load_npz("./npz_reencode_test_new_list.npz")
  85. assert out_npz["metadata"] == metadata_in