# utils.py
import os
import glob
import sys
import argparse
import logging
import json
import subprocess
import traceback

import librosa
import numpy as np
import torch
from scipy.io.wavfile import read

logging.getLogger("numba").setLevel(logging.ERROR)
logging.getLogger("matplotlib").setLevel(logging.ERROR)

MATPLOTLIB_FLAG = False

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging


def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if (
        optimizer is not None
        and not skip_optimizer
        and checkpoint_dict["optimizer"] is not None
    ):
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    saved_state_dict = checkpoint_dict["model"]
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            # Copy the saved tensor only when its shape matches the current model's.
            assert saved_state_dict[k].shape == v.shape, (
                saved_state_dict[k].shape,
                v.shape,
            )
            new_state_dict[k] = saved_state_dict[k]
        except Exception:
            traceback.print_exc()
            # Missing keys land here, and so do shape mismatches, e.g. text_embedding
            # after the text cleaner changed.
            print("error, %s is not in the checkpoint" % k)
            new_state_dict[k] = v
    if hasattr(model, "module"):
        model.module.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(new_state_dict)
    logger.info(
        "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
    )
    return model, optimizer, learning_rate, iteration
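
# Illustrative usage sketch (the model class, optimizer, and checkpoint path below are
# hypothetical, not defined in this module):
#
#   net_g = SynthesizerTrn(...)
#   optim_g = torch.optim.AdamW(net_g.parameters(), lr=1e-4)
#   net_g, optim_g, lr, step = load_checkpoint("logs/s2/G_2333.pth", net_g, optim_g)
#
# Keys missing from the checkpoint, or whose shapes no longer match the current model,
# keep the model's freshly initialized weights instead of raising.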


from time import time as ttime
import shutil


def my_save(fea, path):
    # Work around torch.save failing on paths that contain Chinese characters:
    # save to a temporary ASCII-named file first, then move it into place.
    dir_name = os.path.dirname(path)
    base_name = os.path.basename(path)
    tmp_path = "%s.pth" % (ttime())
    torch.save(fea, tmp_path)
    shutil.move(tmp_path, "%s/%s" % (dir_name, base_name))


def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    logger.info(
        "Saving model and optimizer state at iteration {} to {}".format(
            iteration, checkpoint_path
        )
    )
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    my_save(
        {
            "model": state_dict,
            "iteration": iteration,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )
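
# Illustrative usage sketch (the model, optimizer, and step names are hypothetical):
#
#   save_checkpoint(net_g, optim_g, lr, global_step,
#                   os.path.join(ckpt_dir, "G_%d.pth" % global_step))
#
# Saving is routed through my_save(), so checkpoint_path may live in a directory whose
# name contains non-ASCII characters.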


def summarize(
    writer,
    global_step,
    scalars={},
    histograms={},
    images={},
    audios={},
    audio_sampling_rate=22050,
):
    for k, v in scalars.items():
        writer.add_scalar(k, v, global_step)
    for k, v in histograms.items():
        writer.add_histogram(k, v, global_step)
    for k, v in images.items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in audios.items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
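
# Illustrative usage sketch with TensorBoard (the writer and the logged values are
# assumptions, not part of this module):
#
#   from torch.utils.tensorboard import SummaryWriter
#   writer = SummaryWriter(log_dir="logs/s2")
#   summarize(writer, global_step,
#             scalars={"loss/g/total": loss_gen_all, "learning_rate": lr})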


def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    f_list = glob.glob(os.path.join(dir_path, regex))
    # Sort numerically by the digits in the path so G_10000.pth ranks above G_9999.pth.
    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
    x = f_list[-1]
    print(x)
    return x


def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring is deprecated for binary data; frombuffer + copy gives a writable array.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).copy()
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)
    return data


def plot_alignment_to_numpy(alignment, info=None):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring is deprecated for binary data; frombuffer + copy gives a writable array.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).copy()
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)
    return data
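
# Illustrative sketch: both plotting helpers return HWC uint8 arrays meant to be logged
# as TensorBoard images through summarize() (tensor names below are hypothetical):
#
#   image_dict = {
#       "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].detach().cpu().numpy()),
#       "all/attn": plot_alignment_to_numpy(attn[0, 0].detach().cpu().numpy()),
#   }
#   summarize(writer, global_step, images=image_dict)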


def load_wav_to_torch(full_path):
    # sr=None keeps the file's native sampling rate instead of resampling.
    data, sampling_rate = librosa.load(full_path, sr=None)
    return torch.FloatTensor(data), sampling_rate


def load_filepaths_and_text(filename, split="|"):
    # Each line of the filelist is "path|field|field|..." split on the given delimiter.
    with open(filename, encoding="utf-8") as f:
        filepaths_and_text = [line.strip().split(split) for line in f]
    return filepaths_and_text


def get_hparams(init=True, stage=1):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="./configs/s2.json",
        help="JSON file for configuration",
    )
    parser.add_argument(
        "-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir"
    )
    parser.add_argument(
        "-rs",
        "--resume_step",
        type=int,
        required=False,
        default=None,
        help="resume step",
    )
    args = parser.parse_args()

    config_path = args.config
    with open(config_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.pretrain = args.pretrain
    hparams.resume_step = args.resume_step

    if stage == 1:
        model_dir = hparams.s1_ckpt_dir
    else:
        model_dir = hparams.s2_ckpt_dir
    config_save_path = os.path.join(model_dir, "config.json")

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Keep a copy of the config alongside the checkpoints for reproducibility.
    with open(config_save_path, "w") as f:
        f.write(data)
    return hparams


def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
    """Free up space by deleting old checkpoints.

    Arguments:
        path_to_models  -- path to the model directory
        n_ckpts_to_keep -- number of checkpoints to keep, excluding G_0.pth and D_0.pth
        sort_by_time    -- True: delete the oldest checkpoints (by modification time)
                           False: delete checkpoints with the lowest step number in the filename
    """
    import re

    ckpts_files = [
        f
        for f in os.listdir(path_to_models)
        if os.path.isfile(os.path.join(path_to_models, f))
    ]
    name_key = lambda _f: int(re.compile(r"._(\d+)\.pth").match(_f).group(1))
    time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
    sort_key = time_key if sort_by_time else name_key
    x_sorted = lambda _x: sorted(
        [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
        key=sort_key,
    )
    to_del = [
        os.path.join(path_to_models, fn)
        for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
    ]
    del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
    del_routine = lambda x: [os.remove(x), del_info(x)]
    rs = [del_routine(fn) for fn in to_del]
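
# Illustrative usage sketch (the directory is hypothetical): keep only the two newest
# G_*.pth / D_*.pth checkpoints, never deleting G_0.pth or D_0.pth:
#
#   clean_checkpoints(path_to_models="logs/s2/ckpt", n_ckpts_to_keep=2, sort_by_time=True)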


def get_hparams_from_dir(model_dir):
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_file(config_path):
    with open(config_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    return hparams
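
# Illustrative usage sketch (the config path and keys are assumptions about the JSON
# layout, not guaranteed by this module):
#
#   hps = get_hparams_from_file("./configs/s2.json")
#   print(hps.train.batch_size, hps.data.sampling_rate)
#
# Nested JSON objects become nested HParams instances, so attribute access and
# dict-style indexing (hps["train"]["batch_size"]) are interchangeable.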


def check_git_hash(model_dir):
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        saved_hash = open(path).read()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        open(path, "w").write(cur_hash)


def get_logger(model_dir, filename="train.log"):
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    h = logging.FileHandler(os.path.join(model_dir, filename))
    h.setLevel(logging.DEBUG)
    h.setFormatter(formatter)
    logger.addHandler(h)
    return logger


class HParams:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
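
# Minimal sketch of HParams behaviour (the values are made up):
#
#   hps = HParams(train={"batch_size": 16}, data={"sampling_rate": 32000})
#   hps.train.batch_size          # -> 16 (nested dicts become nested HParams)
#   hps["data"]["sampling_rate"]  # -> 32000 (dict-style access also works)
#   "train" in hps, len(hps)      # -> True, 2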


if __name__ == "__main__":
    print(
        load_wav_to_torch(
            "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
        )
    )