vocos_tab_bark.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import torch
  2. import torchaudio
  3. import gradio as gr
  4. import tempfile
  5. from tts_webui.bark.npz_tools import load_npz
  6. from tts_webui.vocos.get_vocos_model import get_vocos_model
  7. def reconstruct_with_vocos(audio_tokens):
  8. device = "cpu"
  9. vocos = get_vocos_model(model_name="charactr/vocos-encodec-24khz")
  10. audio_tokens_torch = torch.from_numpy(audio_tokens).to(device)
  11. features = vocos.codes_to_features(audio_tokens_torch)
  12. bitrate_id_for_6kbps = 2
  13. return vocos.decode(
  14. features, bandwidth_id=torch.tensor([bitrate_id_for_6kbps], device=device)
  15. )
  16. def vocos_predict(npz_file: tempfile._TemporaryFileWrapper):
  17. if npz_file is None:
  18. print("No file selected")
  19. return None
  20. full_generation = load_npz(npz_file.name)
  21. audio_tokens = full_generation["fine_prompt"]
  22. vocos_output = reconstruct_with_vocos(audio_tokens)
  23. vocos_output = upsample_to_44100(vocos_output)
  24. return (44100, vocos_output.cpu().squeeze().numpy())
  25. def upsample_to_44100(audio):
  26. return torchaudio.functional.resample(audio, orig_freq=24000, new_freq=44100)
  27. def get_audio(npz_file: tempfile._TemporaryFileWrapper):
  28. from tts_webui.bark.get_audio_from_npz import get_audio_from_full_generation
  29. if npz_file is None:
  30. print("No file selected")
  31. return [None, None]
  32. full_generation = load_npz(npz_file.name)
  33. return [
  34. get_audio_from_full_generation(full_generation), # type: ignore
  35. None,
  36. ]
  37. def vocos_bark_ui():
  38. npz_file = gr.File(label="Input NPZ", file_types=[".npz"], interactive=True)
  39. submit = gr.Button(value="Decode")
  40. current = gr.Audio(label="decoded with Encodec:")
  41. output = gr.Audio(label="decoded with Vocos:")
  42. npz_file.change(
  43. fn=get_audio,
  44. inputs=[npz_file],
  45. outputs=[current, output],
  46. api_name="encodec_decode",
  47. )
  48. submit.click(
  49. fn=vocos_predict,
  50. inputs=[npz_file],
  51. outputs=output,
  52. api_name="vocos_npz",
  53. )
  54. def vocos_tab_bark():
  55. with gr.Tab("Vocos (Bark NPZ)"):
  56. vocos_bark_ui()
  57. if __name__ == "__main__":
  58. with gr.Blocks() as demo:
  59. vocos_tab_bark()
  60. demo.launch(server_port=7863)