process_ckpt.py 991 B

12345678910111213141516171819202122232425262728293031
  1. import traceback
  2. from collections import OrderedDict
  3. from time import time as ttime
  4. import shutil,os
  5. import torch
  6. from tools.i18n.i18n import I18nAuto
# Module-level translation helper (project-local I18nAuto).
# NOTE(review): not referenced by the functions in the code shown here;
# presumably kept for other modules that import it — confirm before removing.
i18n = I18nAuto()
  8. def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
  9. dir=os.path.dirname(path)
  10. name=os.path.basename(path)
  11. tmp_path="%s.pth"%(ttime())
  12. torch.save(fea,tmp_path)
  13. shutil.move(tmp_path,"%s/%s"%(dir,name))
  14. def savee(ckpt, name, epoch, steps, hps):
  15. try:
  16. opt = OrderedDict()
  17. opt["weight"] = {}
  18. for key in ckpt.keys():
  19. if "enc_q" in key:
  20. continue
  21. opt["weight"][key] = ckpt[key].half()
  22. opt["config"] = hps
  23. opt["info"] = "%sepoch_%siteration" % (epoch, steps)
  24. # torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
  25. my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
  26. return "Success."
  27. except:
  28. return traceback.format_exc()