# vocos_tab_wav.py
  1. import torch
  2. import torchaudio
  3. import gradio as gr
  4. from tts_webui.utils.list_dir_models import unload_model_button
  5. from tts_webui.vocos.get_vocos_model import get_vocos_model
  6. def vocos_predict(audio: str, bandwidth: int):
  7. vocos = get_vocos_model(model_name="charactr/vocos-encodec-24khz")
  8. bandwidth_id = torch.tensor([bandwidth])
  9. y, sr = torchaudio.load(audio)
  10. if y.size(0) > 1: # mix to mono
  11. y = y.mean(dim=0, keepdim=True)
  12. y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=24000)
  13. with torch.no_grad():
  14. y_hat = vocos(y, bandwidth_id=bandwidth_id)
  15. return (24000, y_hat.squeeze().numpy())
  16. def vocos_wav_ui():
  17. file_input = gr.Audio(
  18. label="Input Audio",
  19. type="filepath",
  20. sources="upload",
  21. interactive=True,
  22. )
  23. options = [str(x) for x in [1.5, 3.0, 6.0, 12.0]]
  24. bandwidth_id = gr.Dropdown(
  25. value=options[0],
  26. choices=options,
  27. type="index",
  28. label="Bandwidth in kbps",
  29. )
  30. submit = gr.Button(value="Reconstruct")
  31. output = gr.Audio(label="Output Audio")
  32. unload_model_button("vocos")
  33. submit.click(
  34. fn=vocos_predict,
  35. inputs=[file_input, bandwidth_id],
  36. outputs=output,
  37. api_name="vocos_wav",
  38. )
  39. def vocos_tab_wav():
  40. with gr.Tab("Vocos (WAV)"):
  41. vocos_wav_ui()
  42. if __name__ == "__main__":
  43. with gr.Blocks() as demo:
  44. vocos_tab_wav()
  45. demo.launch(server_port=7861)