infer_rvc.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # import os
  2. # import sys
  3. # from dotenv import load_dotenv
  4. # from tts_webui.rvc_tab.hide_argv import hide_argv
  5. # os.environ.setdefault("weight_root", "data/models/rvc/checkpoints")
  6. # os.environ.setdefault("weight_uvr5_root", "data/models/rvc/uvr5_weights")
  7. # os.environ.setdefault("index_root", "data/models/rvc/checkpoints")
  8. # os.environ.setdefault("outside_index_root", "data/models/rvc/checkpoints")
  9. # os.environ.setdefault("rmvpe_root", "data/models/rvc/rmvpe")
  10. # import rvc_pkg
  11. # rvc_dir = os.path.dirname(rvc_pkg.__file__)
  12. # sys.path.append(rvc_dir)
  13. # from rvc_pkg.configs.config import Config
  14. # from rvc_pkg.infer.modules.vc.modules import VC
  15. # sys.path.remove(rvc_dir)
  16. # from tts_webui.rvc_tab.get_and_load_hubert import get_and_load_hubert_new, download_rmvpe
  17. # last_model_path = None
  18. # vc = None
  19. # def infer_rvc(
  20. # input_path: str, # Input path
  21. # index_path_2: str, # Index path
  22. # model_name: str, # Model name (stored in assets/weight_root)
  23. # device: str, # Device
  24. # f0up_key: int = 0, # F0up key
  25. # f0method: str = "harvest", # F0 method (harvest or pm)
  26. # index_rate: float = 0.66, # Index rate
  27. # is_half: bool = False, # Use half -> True
  28. # filter_radius: int = 3, # Filter radius
  29. # resample_sr: int = 0, # Resample sample rate
  30. # rms_mix_rate: float = 1, # RMS mix rate
  31. # protect: float = 0.33, # Protect breath sounds
  32. # ):
  33. # global last_model_path
  34. # load_dotenv()
  35. # with hide_argv():
  36. # config = Config()
  37. # config.device = device if device else config.device
  38. # config.is_half = is_half if is_half else config.is_half
  39. # global vc
  40. # if vc is None:
  41. # vc = VC(config)
  42. # if vc.hubert_model is None:
  43. # vc.hubert_model = get_and_load_hubert_new(config)
  44. # if last_model_path != model_name:
  45. # vc.get_vc(model_name)
  46. # last_model_path = model_name
  47. # if f0method == "rmvpe":
  48. # download_rmvpe()
  49. # message, wav_opt = vc.vc_single(
  50. # 0,
  51. # input_path,
  52. # f0up_key,
  53. # None,
  54. # f0method,
  55. # None,
  56. # os.path.join("data\\models\\rvc\\checkpoints\\", index_path_2),
  57. # index_rate,
  58. # filter_radius,
  59. # resample_sr,
  60. # rms_mix_rate,
  61. # protect,
  62. # )
  63. # print(message)
  64. # return wav_opt, {
  65. # "original_audio_path": input_path,
  66. # "index_path": index_path_2,
  67. # "model_path": model_name,
  68. # "f0method": f0method,
  69. # "f0up_key": f0up_key,
  70. # "index_rate": index_rate,
  71. # "device": device,
  72. # "is_half": is_half,
  73. # "filter_radius": filter_radius,
  74. # "resample_sr": resample_sr,
  75. # "rms_mix_rate": rms_mix_rate,
  76. # "protect": protect,
  77. # }
  78. # if __name__ == "__main__":
  79. # rate, wav_data = infer_rvc(
  80. # input_path="sample.wav",
  81. # model_name="Voltina.pth",
  82. # index_path_2="Voltina.index",
  83. # device="cuda:0",
  84. # f0up_key=0,
  85. # f0method="harvest",
  86. # index_rate=0.66,
  87. # is_half=False,
  88. # filter_radius=3,
  89. # resample_sr=0,
  90. # rms_mix_rate=1,
  91. # protect=0.33,
  92. # )
  93. # from scipy.io.wavfile import write
  94. # write("out.wav", rate, wav_data)