mdxnet.py

import os
import logging

logger = logging.getLogger(__name__)

import librosa
import numpy as np
import soundfile as sf
import torch
from tqdm import tqdm

# Torch-side STFT/iSTFT work runs on CPU; the ONNX session below picks its
# own execution provider (CUDA/DirectML/CPU).
cpu = torch.device("cpu")


class ConvTDFNetTrim:
    """STFT/iSTFT front end for the Conv-TDF ONNX model. This class holds no
    network weights; it only converts between waveforms and the stacked
    spectrogram layout the .onnx file expects."""

    def __init__(
        self, device, model_name, target_name, L, dim_f, dim_t, n_fft, hop=1024
    ):
        super(ConvTDFNetTrim, self).__init__()

        self.dim_f = dim_f
        self.dim_t = 2**dim_t
        self.n_fft = n_fft
        self.hop = hop
        self.n_bins = self.n_fft // 2 + 1
        # Samples per chunk so that the centered STFT yields exactly dim_t frames.
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(
            device
        )
        self.target_name = target_name
        self.blender = "blender" in model_name

        # 4 channels: (left, right) x (real, imaginary).
        self.dim_c = 4
        out_c = self.dim_c * 4 if target_name == "*" else self.dim_c
        # Zero padding that restores the frequency bins cropped off in stft().
        self.freq_pad = torch.zeros(
            [1, out_c, self.n_bins - self.dim_f, self.dim_t]
        ).to(device)

        self.n = L // 2

    def stft(self, x):
        # (B, 2, chunk_size) -> (B, 4, dim_f, dim_t): stack stereo and
        # real/imag into the channel axis, then crop the spectrum to the
        # dim_f bins the model was trained on.
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            return_complex=True,
        )
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
            [-1, self.dim_c, self.n_bins, self.dim_t]
        )
        return x[:, :, : self.dim_f]

    def istft(self, x, freq_pad=None):
        # Inverse of stft(): re-pad the cropped bins, unstack the channel
        # axis back into stereo + real/imag, and resynthesize waveforms.
        freq_pad = (
            self.freq_pad.repeat([x.shape[0], 1, 1, 1])
            if freq_pad is None
            else freq_pad
        )
        x = torch.cat([x, freq_pad], -2)
        c = 4 * 2 if self.target_name == "*" else 2
        x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape(
            [-1, 2, self.n_bins, self.dim_t]
        )
        x = x.permute([0, 2, 3, 1])
        x = x.contiguous()
        x = torch.view_as_complex(x)
        x = torch.istft(
            x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True
        )
        return x.reshape([-1, c, self.chunk_size])
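
# A minimal round-trip sketch of the tensor shapes above (illustrative only,
# not executed by this module; the hyperparameters are the ones MDXNetDereverb
# passes in further down):
#
#     net = ConvTDFNetTrim(cpu, "Conv-TDF", "vocals", L=11,
#                          dim_f=3072, dim_t=9, n_fft=6144)
#     wave = torch.randn(1, 2, net.chunk_size)  # (batch, stereo, 1024*(512-1))
#     spec = net.stft(wave)                     # -> (1, 4, 3072, 512)
#     back = net.istft(spec)                    # -> (1, 2, 523264), wave's shape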


def get_models(device, dim_f, dim_t, n_fft):
    # Despite the plural name, this builds a single "vocals" model wrapper.
    return ConvTDFNetTrim(
        device=device,
        model_name="Conv-TDF",
        target_name="vocals",
        L=11,
        dim_f=dim_f,
        dim_t=dim_t,
        n_fft=n_fft,
    )


class Predictor:
    def __init__(self, args):
        import onnxruntime as ort

        logger.info(ort.get_available_providers())
        self.args = args
        # Torch-side STFT/iSTFT wrapper; the network weights live in the ONNX file.
        self.model_ = get_models(
            device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft
        )
        self.model = ort.InferenceSession(
            os.path.join(args.onnx, self.model_.target_name + ".onnx"),
            providers=[
                "CUDAExecutionProvider",
                "DmlExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        logger.info("ONNX load done")
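
        # onnxruntime falls back through the provider list at session creation,
        # so on a machine without CUDA/DirectML the same call lands on
        # CPUExecutionProvider. To see what was actually selected, one could
        # log it here (illustrative, not in the original code):
        #
        #     logger.info(self.model.get_providers())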

    def demix(self, mix):
        """Split mix (2, n_samples) into overlapping chunks keyed by their
        offset, separate each chunk, and return sources (1, 2, n_samples)."""
        samples = mix.shape[-1]
        margin = self.args.margin
        chunk_size = self.args.chunks * 44100
        assert margin != 0, "margin cannot be zero!"
        if margin > chunk_size:
            margin = chunk_size
        segmented_mix = {}
        if self.args.chunks == 0 or samples < chunk_size:
            chunk_size = samples
        counter = -1
        for skip in range(0, samples, chunk_size):
            counter += 1
            # Every chunk except the first is extended `margin` samples to the
            # left, and every chunk except the last to the right, so that
            # neighbours overlap and the seams can be trimmed in demix_base().
            s_margin = 0 if counter == 0 else margin
            end = min(skip + chunk_size + margin, samples)
            start = skip - s_margin
            segmented_mix[skip] = mix[:, start:end].copy()
            if end == samples:
                break
        return self.demix_base(segmented_mix, margin_size=margin)
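
    # Chunking sketch (illustrative, using chunks=15 and margin=44100 from
    # MDXNetDereverb below): for a 90 s file (3,969,000 samples at 44.1 kHz),
    # chunk_size = 15 * 44100 = 661500, so segmented_mix comes out as
    #
    #     {0:       mix[:,       0 : 705600],   # + right margin only
    #      661500:  mix[:,  617400 : 1367100],  # both margins
    #      ...
    #      3307500: mix[:, 3263400 : 3969000]}  # + left margin, runs to EOF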

    def demix_base(self, mixes, margin_size):
        chunked_sources = []
        progress_bar = tqdm(total=len(mixes))
        progress_bar.set_description("Processing")
        for mix in mixes:
            cmix = mixes[mix]
            sources = []
            n_sample = cmix.shape[1]
            model = self.model_
            # `trim` samples on each side absorb STFT edge effects; `gen_size`
            # is the usable audio each window contributes.
            trim = model.n_fft // 2
            gen_size = model.chunk_size - 2 * trim
            pad = gen_size - n_sample % gen_size
            mix_p = np.concatenate(
                (np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
            )
            # Slice the padded chunk into overlapping model-sized windows.
            mix_waves = []
            i = 0
            while i < n_sample + pad:
                waves = np.array(mix_p[:, i : i + model.chunk_size])
                mix_waves.append(waves)
                i += gen_size
            mix_waves = torch.tensor(np.array(mix_waves), dtype=torch.float32).to(cpu)
            with torch.no_grad():
                _ort = self.model
                spek = model.stft(mix_waves)
                if self.args.denoise:
                    # Denoise trick: run the net on the spectrogram and on its
                    # negation, then average the (re-negated) results so that
                    # sign-dependent artifacts cancel out.
                    spec_pred = (
                        -_ort.run(None, {"input": -spek.cpu().numpy()})[0] * 0.5
                        + _ort.run(None, {"input": spek.cpu().numpy()})[0] * 0.5
                    )
                    tar_waves = model.istft(torch.tensor(spec_pred))
                else:
                    tar_waves = model.istft(
                        torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0])
                    )
                # Drop the trim regions, stitch the windows back together, and
                # cut off the zero padding added above.
                tar_signal = (
                    tar_waves[:, :, trim:-trim]
                    .transpose(0, 1)
                    .reshape(2, -1)
                    .numpy()[:, :-pad]
                )
                # Trim the inter-chunk margins added in demix(): the first
                # chunk keeps its head, the last chunk keeps its tail.
                start = 0 if mix == 0 else margin_size
                end = None if mix == list(mixes.keys())[::-1][0] else -margin_size
                if margin_size == 0:
                    end = None
                sources.append(tar_signal[:, start:end])
                progress_bar.update(1)
            chunked_sources.append(sources)
        _sources = np.concatenate(chunked_sources, axis=-1)
        # del self.model
        progress_bar.close()
        return _sources
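
    # Window arithmetic sketch (illustrative): with n_fft=6144 and
    # chunk_size=523264, trim = 6144 // 2 = 3072 and
    # gen_size = 523264 - 2 * 3072 = 517120, so each forward pass yields
    # 517120 "clean" samples, and consecutive windows overlap by
    # 2 * trim = 6144 samples that are discarded at the seams.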

    def prediction(self, m, vocal_root, others_root, format):
        os.makedirs(vocal_root, exist_ok=True)
        os.makedirs(others_root, exist_ok=True)
        basename = os.path.basename(m)
        mix, rate = librosa.load(m, mono=False, sr=44100)
        if mix.ndim == 1:
            # Duplicate a mono signal into two channels; the model is stereo.
            mix = np.asfortranarray([mix, mix])
        mix = mix.T
        sources = self.demix(mix.T)
        # `opt` is the stem the model separates out (written as "others");
        # "main_vocal" is the residual mix - opt.
        opt = sources[0].T
        if format in ["wav", "flac"]:
            sf.write(
                "%s/%s_main_vocal.%s" % (vocal_root, basename, format), mix - opt, rate
            )
            sf.write("%s/%s_others.%s" % (others_root, basename, format), opt, rate)
        else:
            # Write wav first, transcode with ffmpeg, then drop the wav.
            # Paths are quoted so spaces survive the shell.
            path_vocal = "%s/%s_main_vocal.wav" % (vocal_root, basename)
            path_other = "%s/%s_others.wav" % (others_root, basename)
            sf.write(path_vocal, mix - opt, rate)
            sf.write(path_other, opt, rate)
            opt_path_vocal = path_vocal[:-4] + ".%s" % format
            opt_path_other = path_other[:-4] + ".%s" % format
            if os.path.exists(path_vocal):
                os.system(
                    'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path_vocal, opt_path_vocal)
                )
                if os.path.exists(opt_path_vocal):
                    try:
                        os.remove(path_vocal)
                    except Exception:
                        pass
            if os.path.exists(path_other):
                os.system(
                    'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path_other, opt_path_other)
                )
                if os.path.exists(opt_path_other):
                    try:
                        os.remove(path_other)
                    except Exception:
                        pass


class MDXNetDereverb:
    # This object doubles as the `args` namespace Predictor reads its
    # settings from.
    def __init__(self, chunks):
        self.onnx = "%s/uvr5_weights/onnx_dereverb_By_FoxJoy" % os.path.dirname(
            os.path.abspath(__file__)
        )
        # shifts/mixing are kept from the original UVR5 options but are not
        # consumed by Predictor.
        self.shifts = 10  # "Predict with randomised equivariant stabilisation"
        self.mixing = "min_mag"  # ["default", "min_mag", "max_mag"]
        self.chunks = chunks
        self.margin = 44100
        self.dim_t = 9
        self.dim_f = 3072
        self.n_fft = 6144
        self.denoise = True
        self.pred = Predictor(self)
        self.device = cpu

    def _path_audio_(self, input, vocal_root, others_root, format, is_hp3=False):
        # is_hp3 is accepted for interface compatibility and ignored here.
        self.pred.prediction(input, vocal_root, others_root, format)
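

# A minimal usage sketch (not part of the original module; the input path and
# output directories are hypothetical placeholders):
#
#     if __name__ == "__main__":
#         dereverb = MDXNetDereverb(chunks=15)
#         dereverb._path_audio_(
#             "input.wav",   # hypothetical source file
#             "opt/vocal",   # receives <name>_main_vocal.<format>
#             "opt/others",  # receives <name>_others.<format>
#             "wav",
#         )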