123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- import torchaudio
- import torch
- import gradio as gr
- from tts_webui.utils.manage_model_state import manage_model_state
- from tts_webui.utils.list_dir_models import unload_model_button
@manage_model_state("demucs")
def _get_demucs_model(model_name="htdemucs"):
    """Return the pretrained Demucs model called ``model_name``.

    The ``demucs`` package is imported at call time rather than at module
    import. The ``manage_model_state("demucs")`` decorator wraps this loader;
    presumably it caches the loaded model under the "demucs" key -- confirm in
    ``tts_webui.utils.manage_model_state``.
    """
    from demucs import pretrained as demucs_pretrained

    model = demucs_pretrained.get_model(model_name)
    return model
def apply_demucs(wav, sr):
    """Run the "htdemucs" model on ``wav`` and return the first batch item.

    Args:
        wav: Audio tensor to separate. It is passed to ``convert_audio`` and
            then to ``apply_model``; the only caller in this file adds a
            leading batch axis before calling.
        sr: Sample rate of ``wav``.

    Returns:
        Element ``0`` of the ``apply_model`` result.
    """
    from demucs.apply import apply_model
    from demucs.audio import convert_audio

    model = _get_demucs_model(model_name="htdemucs")

    # Prefer the GPU when one is available, otherwise fall back to the CPU.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    target_device = torch.device(device_name)

    # Bring the input to the sample rate and channel count the model declares.
    prepared = convert_audio(wav, sr, model.samplerate, model.audio_channels)
    separated = apply_model(model, prepared, device=target_device)
    return separated[0]  # type: ignore
# Stem names, in the order demucs_audio returns them and the UI lays out
# its output players.
COMPONENTS = [
    "drums",
    "bass",
    "other",
    "vocals",
]
def demucs_audio(audio):
    """Separate an audio file into stems with the "htdemucs" model.

    Args:
        audio: Input passed straight to ``torchaudio.load`` (the UI in this
            file supplies a file path).

    Returns:
        A list with one entry per name in ``COMPONENTS``, in that order. Each
        entry is a ``(sample_rate, samples)`` tuple where ``samples`` is an
        int16 NumPy array taken from the first channel of the stem, or
        ``None`` when the model reports no source with that name.
    """
    demucs_model = _get_demucs_model(model_name="htdemucs")
    wav, sr = torchaudio.load(audio)
    # apply_demucs is handed a leading batch axis and returns batch item 0.
    out = apply_demucs(wav=wav.unsqueeze(0), sr=sr)

    def to_wav(tensor):
        # Make mono by picking the first channel.
        mono = tensor[0].detach().cpu().squeeze()
        # Separated stems are not guaranteed to stay within [-1, 1]. Without
        # clamping, scaled samples beyond the int16 range overflow in the
        # cast below and come out as loud wrap-around artifacts.
        mono = mono.clamp(-1.0, 1.0)
        return (mono.numpy() * 32767).astype("int16")

    # Pair every source name the model reports with its converted audio.
    audios_dict = {
        name: (demucs_model.samplerate, to_wav(source))
        for name, source in zip(demucs_model.sources, out)
    }
    return [audios_dict.get(key) for key in COMPONENTS]
def demucs_ui():
    """Build the Demucs interface inside the currently open Gradio context.

    Lays out an intro text, an input column (audio input, "Separate" button,
    unload-model button) and an output column with one audio player per entry
    in ``COMPONENTS``, then wires the button to ``demucs_audio``.
    """
    intro_text = """
# Demucs
Gradio demo for Demucs: Music Source Separation in the Waveform Domain.

To use it, simply upload your audio, and click "Separate".
"""
    gr.Markdown(intro_text)

    with gr.Row(equal_height=False):
        with gr.Column():
            source_audio = gr.Audio(label="Input", type="filepath")
            separate_btn = gr.Button("Separate")
            unload_model_button("demucs")

        with gr.Column():
            # One output player per stem, in COMPONENTS order, so they line
            # up with the list demucs_audio returns.
            stem_players = []
            for stem in COMPONENTS:
                stem_players.append(gr.Audio(label=stem))

    separate_btn.click(
        fn=demucs_audio,
        inputs=source_audio,
        outputs=stem_players,
        api_name="demucs",
    )
def demucs_tab():
    """Create the "Demucs" tab (id "demucs") and populate it with demucs_ui."""
    tab = gr.Tab("Demucs", id="demucs")
    with tab:
        demucs_ui()
if __name__ == "__main__":
    # If an earlier `demo` is still bound in this scope, close it before
    # building a new one (presumably for re-runs in an interactive session,
    # so the port is released -- confirm intent).
    module_scope = locals()
    if "demo" in module_scope:
        module_scope["demo"].close()

    with gr.Blocks() as demo:
        demucs_tab()

    demo.launch(server_port=7770)
|