# 2-get-hubert-wav32k.py (4.0 KB)
  1. # -*- coding: utf-8 -*-
  2. import sys,os
  3. inp_text= os.environ.get("inp_text")
  4. inp_wav_dir= os.environ.get("inp_wav_dir")
  5. exp_name= os.environ.get("exp_name")
  6. i_part= os.environ.get("i_part")
  7. all_parts= os.environ.get("all_parts")
  8. os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
  9. from feature_extractor import cnhubert
  10. opt_dir= os.environ.get("opt_dir")
  11. cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
  12. is_half=eval(os.environ.get("is_half","True"))
  13. import pdb,traceback,numpy as np,logging
  14. from scipy.io import wavfile
  15. import librosa,torch
  16. now_dir = os.getcwd()
  17. sys.path.append(now_dir)
  18. from my_utils import load_audio
  19. # from config import cnhubert_base_path
  20. # cnhubert.cnhubert_base_path=cnhubert_base_path
  21. # inp_text=sys.argv[1]
  22. # inp_wav_dir=sys.argv[2]
  23. # exp_name=sys.argv[3]
  24. # i_part=sys.argv[4]
  25. # all_parts=sys.argv[5]
  26. # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
  27. # cnhubert.cnhubert_base_path=sys.argv[7]
  28. # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
  29. from time import time as ttime
  30. import shutil
  31. def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
  32. dir=os.path.dirname(path)
  33. name=os.path.basename(path)
  34. # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
  35. tmp_path="%s%s.pth"%(ttime(),i_part)
  36. torch.save(fea,tmp_path)
  37. shutil.move(tmp_path,"%s/%s"%(dir,name))
  38. hubert_dir="%s/4-cnhubert"%(opt_dir)
  39. wav32dir="%s/5-wav32k"%(opt_dir)
  40. os.makedirs(opt_dir,exist_ok=True)
  41. os.makedirs(hubert_dir,exist_ok=True)
  42. os.makedirs(wav32dir,exist_ok=True)
  43. maxx=0.95
  44. alpha=0.5
  45. if torch.cuda.is_available():
  46. device = "cuda:0"
  47. elif torch.backends.mps.is_available():
  48. device = "mps"
  49. else:
  50. device = "cpu"
  51. model=cnhubert.get_model()
  52. # is_half=False
  53. if(is_half==True):
  54. model=model.half().to(device)
  55. else:
  56. model = model.to(device)
  57. nan_fails=[]
  58. def name2go(wav_name,wav_path):
  59. hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
  60. if(os.path.exists(hubert_path)):return
  61. tmp_audio = load_audio(wav_path, 32000)
  62. tmp_max = np.abs(tmp_audio).max()
  63. if tmp_max > 2.2:
  64. print("%s-filtered,%s" % (wav_name, tmp_max))
  65. return
  66. tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
  67. tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio
  68. tmp_audio = librosa.resample(
  69. tmp_audio32b, orig_sr=32000, target_sr=16000
  70. )#不是重采样问题
  71. tensor_wav16 = torch.from_numpy(tmp_audio)
  72. if (is_half == True):
  73. tensor_wav16=tensor_wav16.half().to(device)
  74. else:
  75. tensor_wav16 = tensor_wav16.to(device)
  76. ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
  77. if np.isnan(ssl.detach().numpy()).sum()!= 0:
  78. nan_fails.append(wav_name)
  79. print("nan filtered:%s"%wav_name)
  80. return
  81. wavfile.write(
  82. "%s/%s"%(wav32dir,wav_name),
  83. 32000,
  84. tmp_audio32.astype("int16"),
  85. )
  86. my_save(ssl,hubert_path )
  87. with open(inp_text,"r",encoding="utf8")as f:
  88. lines=f.read().strip("\n").split("\n")
  89. for line in lines[int(i_part)::int(all_parts)]:
  90. try:
  91. # wav_name,text=line.split("\t")
  92. wav_name, spk_name, language, text = line.split("|")
  93. if (inp_wav_dir !=None):
  94. wav_name = os.path.basename(wav_name)
  95. wav_path = "%s/%s"%(inp_wav_dir, wav_name)
  96. else:
  97. wav_path=wav_name
  98. wav_name = os.path.basename(wav_name)
  99. name2go(wav_name,wav_path)
  100. except:
  101. print(line,traceback.format_exc())
  102. if(len(nan_fails)>0 and is_half==True):
  103. is_half=False
  104. model=model.float()
  105. for wav_name in nan_fails:
  106. try:
  107. name2go(wav_name)
  108. except:
  109. print(wav_name,traceback.format_exc())