2
0

demucs_tab.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import torchaudio
  2. import torch
  3. import gradio as gr
  4. from tts_webui.utils.manage_model_state import manage_model_state
  5. from tts_webui.utils.list_dir_models import unload_model_button
  6. @manage_model_state("demucs")
  7. def _get_demucs_model(model_name="htdemucs"):
  8. from demucs import pretrained
  9. return pretrained.get_model(model_name)
  10. def apply_demucs(wav, sr):
  11. from demucs.audio import convert_audio
  12. from demucs.apply import apply_model
  13. demucs_model = _get_demucs_model(model_name="htdemucs")
  14. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  15. wav = convert_audio(wav, sr, demucs_model.samplerate, demucs_model.audio_channels)
  16. return apply_model(demucs_model, wav, device=device)[0] # type: ignore
  17. COMPONENTS = ["drums", "bass", "other", "vocals"]
  18. def demucs_audio(audio):
  19. demucs_model = _get_demucs_model(model_name="htdemucs")
  20. wav, sr = torchaudio.load(audio)
  21. out = apply_demucs(wav=wav.unsqueeze(0), sr=sr)
  22. def to_wav(tensor):
  23. # make mono by picking first channel
  24. tensor = tensor[0]
  25. tensor = tensor.detach().cpu().squeeze().numpy()
  26. tensor = (tensor * 32767).astype("int16")
  27. return tensor
  28. def get_audios(out):
  29. for name, source in zip(demucs_model.sources, out):
  30. yield name, (demucs_model.samplerate, to_wav(source))
  31. audios_dict = dict(get_audios(out))
  32. return [audios_dict.get(key) for key in COMPONENTS]
  33. def demucs_ui():
  34. gr.Markdown(
  35. """
  36. # Demucs
  37. Gradio demo for Demucs: Music Source Separation in the Waveform Domain.
  38. To use it, simply upload your audio, and click "Separate".
  39. """
  40. )
  41. with gr.Row(equal_height=False):
  42. with gr.Column():
  43. demucs_input = gr.Audio(label="Input", type="filepath")
  44. button = gr.Button("Separate")
  45. unload_model_button("demucs")
  46. with gr.Column():
  47. outputs = [gr.Audio(label=label) for label in COMPONENTS]
  48. button.click(
  49. inputs=demucs_input,
  50. outputs=outputs,
  51. fn=demucs_audio,
  52. api_name="demucs",
  53. )
  54. def demucs_tab():
  55. with gr.Tab("Demucs", id="demucs"):
  56. demucs_ui()
  57. if __name__ == "__main__":
  58. if "demo" in locals():
  59. locals()["demo"].close()
  60. with gr.Blocks() as demo:
  61. demucs_tab()
  62. demo.launch(
  63. server_port=7770,
  64. )