123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- # -*- coding: utf-8 -*-
- import sys,os
- inp_text= os.environ.get("inp_text")
- inp_wav_dir= os.environ.get("inp_wav_dir")
- exp_name= os.environ.get("exp_name")
- i_part= os.environ.get("i_part")
- all_parts= os.environ.get("all_parts")
- os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
- from feature_extractor import cnhubert
- opt_dir= os.environ.get("opt_dir")
- cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
- is_half=eval(os.environ.get("is_half","True"))
- import pdb,traceback,numpy as np,logging
- from scipy.io import wavfile
- import librosa,torch
- now_dir = os.getcwd()
- sys.path.append(now_dir)
- from my_utils import load_audio
- # from config import cnhubert_base_path
- # cnhubert.cnhubert_base_path=cnhubert_base_path
- # inp_text=sys.argv[1]
- # inp_wav_dir=sys.argv[2]
- # exp_name=sys.argv[3]
- # i_part=sys.argv[4]
- # all_parts=sys.argv[5]
- # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
- # cnhubert.cnhubert_base_path=sys.argv[7]
- # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
- from time import time as ttime
- import shutil
- def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
- dir=os.path.dirname(path)
- name=os.path.basename(path)
- # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
- tmp_path="%s%s.pth"%(ttime(),i_part)
- torch.save(fea,tmp_path)
- shutil.move(tmp_path,"%s/%s"%(dir,name))
- hubert_dir="%s/4-cnhubert"%(opt_dir)
- wav32dir="%s/5-wav32k"%(opt_dir)
- os.makedirs(opt_dir,exist_ok=True)
- os.makedirs(hubert_dir,exist_ok=True)
- os.makedirs(wav32dir,exist_ok=True)
- maxx=0.95
- alpha=0.5
- if torch.cuda.is_available():
- device = "cuda:0"
- elif torch.backends.mps.is_available():
- device = "mps"
- else:
- device = "cpu"
- model=cnhubert.get_model()
- # is_half=False
- if(is_half==True):
- model=model.half().to(device)
- else:
- model = model.to(device)
- nan_fails=[]
- def name2go(wav_name,wav_path):
- hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
- if(os.path.exists(hubert_path)):return
- tmp_audio = load_audio(wav_path, 32000)
- tmp_max = np.abs(tmp_audio).max()
- if tmp_max > 2.2:
- print("%s-filtered,%s" % (wav_name, tmp_max))
- return
- tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
- tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio
- tmp_audio = librosa.resample(
- tmp_audio32b, orig_sr=32000, target_sr=16000
- )#不是重采样问题
- tensor_wav16 = torch.from_numpy(tmp_audio)
- if (is_half == True):
- tensor_wav16=tensor_wav16.half().to(device)
- else:
- tensor_wav16 = tensor_wav16.to(device)
- ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
- if np.isnan(ssl.detach().numpy()).sum()!= 0:
- nan_fails.append(wav_name)
- print("nan filtered:%s"%wav_name)
- return
- wavfile.write(
- "%s/%s"%(wav32dir,wav_name),
- 32000,
- tmp_audio32.astype("int16"),
- )
- my_save(ssl,hubert_path )
- with open(inp_text,"r",encoding="utf8")as f:
- lines=f.read().strip("\n").split("\n")
- for line in lines[int(i_part)::int(all_parts)]:
- try:
- # wav_name,text=line.split("\t")
- wav_name, spk_name, language, text = line.split("|")
- if (inp_wav_dir !=None):
- wav_name = os.path.basename(wav_name)
- wav_path = "%s/%s"%(inp_wav_dir, wav_name)
- else:
- wav_path=wav_name
- wav_name = os.path.basename(wav_name)
- name2go(wav_name,wav_path)
- except:
- print(line,traceback.format_exc())
- if(len(nan_fails)>0 and is_half==True):
- is_half=False
- model=model.float()
- for wav_name in nan_fails:
- try:
- name2go(wav_name)
- except:
- print(wav_name,traceback.format_exc())
|