# 3-get-semantic.py
  1. import os
  2. inp_text = os.environ.get("inp_text")
  3. exp_name = os.environ.get("exp_name")
  4. i_part = os.environ.get("i_part")
  5. all_parts = os.environ.get("all_parts")
  6. os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
  7. opt_dir = os.environ.get("opt_dir")
  8. pretrained_s2G = os.environ.get("pretrained_s2G")
  9. s2config_path = os.environ.get("s2config_path")
  10. is_half = eval(os.environ.get("is_half", "True"))
  11. import math, traceback
  12. import multiprocessing
  13. import sys, pdb
  14. now_dir = os.getcwd()
  15. sys.path.append(now_dir)
  16. from random import shuffle
  17. import torch.multiprocessing as mp
  18. from glob import glob
  19. from tqdm import tqdm
  20. import logging, librosa, utils, torch
  21. from module.models import SynthesizerTrn
  22. logging.getLogger("numba").setLevel(logging.WARNING)
  23. # from config import pretrained_s2G
  24. # inp_text=sys.argv[1]
  25. # exp_name=sys.argv[2]
  26. # i_part=sys.argv[3]
  27. # all_parts=sys.argv[4]
  28. # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[5]
  29. # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
  30. hubert_dir = "%s/4-cnhubert" % (opt_dir)
  31. semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
  32. if os.path.exists(semantic_path) == False:
  33. os.makedirs(opt_dir, exist_ok=True)
  34. if torch.cuda.is_available():
  35. device = "cuda"
  36. elif torch.backends.mps.is_available():
  37. device = "mps"
  38. else:
  39. device = "cpu"
  40. hps = utils.get_hparams_from_file(s2config_path)
  41. vq_model = SynthesizerTrn(
  42. hps.data.filter_length // 2 + 1,
  43. hps.train.segment_size // hps.data.hop_length,
  44. n_speakers=hps.data.n_speakers,
  45. **hps.model
  46. )
  47. if is_half == True:
  48. vq_model = vq_model.half().to(device)
  49. else:
  50. vq_model = vq_model.to(device)
  51. vq_model.eval()
  52. # utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True)
  53. # utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
  54. print(
  55. vq_model.load_state_dict(
  56. torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
  57. )
  58. )
  59. def name2go(wav_name, lines):
  60. hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
  61. if os.path.exists(hubert_path) == False:
  62. return
  63. ssl_content = torch.load(hubert_path, map_location="cpu")
  64. if is_half == True:
  65. ssl_content = ssl_content.half().to(device)
  66. else:
  67. ssl_content = ssl_content.to(device)
  68. codes = vq_model.extract_latent(ssl_content)
  69. semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()])
  70. lines.append("%s\t%s" % (wav_name, semantic))
  71. with open(inp_text, "r", encoding="utf8") as f:
  72. lines = f.read().strip("\n").split("\n")
  73. lines1 = []
  74. for line in lines[int(i_part) :: int(all_parts)]:
  75. # print(line)
  76. try:
  77. # wav_name,text=line.split("\t")
  78. wav_name, spk_name, language, text = line.split("|")
  79. wav_name = os.path.basename(wav_name)
  80. # name2go(name,lines1)
  81. name2go(wav_name, lines1)
  82. except:
  83. print(line, traceback.format_exc())
  84. with open(semantic_path, "w", encoding="utf8") as f:
  85. f.write("\n".join(lines1))