main.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. from tts_webui.utils.save_waveform_plot import plot_waveform_as_image
  2. import gradio as gr
  3. import torchaudio
  4. import torch
  5. def extension__tts_generation_webui():
  6. simple_remixer_ui()
  7. return {
  8. "package_name": "extension_simple_remixer",
  9. "name": "Simple Remixer",
  10. "version": "0.0.1",
  11. "requirements": "git+https://github.com/rsxdalv/extension_simple_remixer@main",
  12. "description": "Simple remixer allows concatenating multiple audio files and mixing them together.",
  13. "extension_type": "interface",
  14. "extension_class": "tools",
  15. "author": "rsxdalv",
  16. "extension_author": "rsxdalv",
  17. "license": "MIT",
  18. "website": "https://github.com/rsxdalv/extension_simple_remixer",
  19. "extension_website": "https://github.com/rsxdalv/extension_simple_remixer",
  20. "extension_platform_version": "0.0.1",
  21. }
  22. def gr_mini_button(value, **kwargs):
  23. return gr.Button(
  24. value,
  25. elem_classes="btn-sm material-symbols-outlined",
  26. size="sm",
  27. **kwargs,
  28. )
  29. def simple_remixer_ui():
  30. input_audio = gr.Audio(label="Input Audio")
  31. def create_slot(id=0):
  32. with gr.Group():
  33. audio = gr.Audio(label=f"Slot {str(id)}")
  34. with gr.Row():
  35. clear = gr_mini_button("delete").click(
  36. fn=lambda: [gr.Audio(None)],
  37. outputs=[audio],
  38. )
  39. copy_from_input = gr_mini_button("keyboard_return").click(
  40. fn=lambda input_value: [gr.Audio(input_value)],
  41. inputs=[input_audio],
  42. outputs=[audio],
  43. )
  44. return audio
  45. def slot_stack(i):
  46. with gr.Column(variant="compact"):
  47. a = create_slot(i)
  48. b = create_slot(i)
  49. c = create_slot(i)
  50. return a, b, c
  51. with gr.Row():
  52. slots = [slot_stack(i) for i in range(3)]
  53. slots = [x for y in slots for x in y]
  54. concat = gr.Button("Concatenate")
  55. output_audio = gr.Audio(label="Output Audio")
  56. def concat_audio(*slot_audios):
  57. sample_rate = max(x[0] for x in slot_audios if x is not None)
  58. resampled_audios = [
  59. resample_from_to(x[0], sample_rate, x[1]) if x is not None else None
  60. for x in slot_audios
  61. ]
  62. stacked_audios = [
  63. resampled_audios[i : i + 3] for i in range(0, len(slot_audios), 3)
  64. ]
  65. def mix_audio(x):
  66. non_null_audios = [i for i in x if i is not None]
  67. if not non_null_audios:
  68. return None
  69. max_len = max(i.shape[0] for i in non_null_audios)
  70. stack = torch.stack(
  71. [
  72. torch.nn.functional.pad(
  73. i,
  74. (
  75. 0,
  76. max_len - i.shape[0],
  77. ),
  78. )
  79. for i in non_null_audios
  80. ]
  81. )
  82. return torch.sum(
  83. stack,
  84. dim=0,
  85. )
  86. merged_audios = [mix_audio(x) for x in stacked_audios]
  87. if non_null_audios := [x for x in merged_audios if x is not None]:
  88. return gr.Audio(
  89. (sample_rate, torch.cat(non_null_audios).cpu().numpy())
  90. )
  91. else:
  92. return gr.Audio(None)
  93. def resample_from_to(in_sr: int, out_sr: int, in_wav):
  94. return torchaudio.transforms.Resample(in_sr, out_sr)(
  95. torch.from_numpy(in_wav).float()
  96. )
  97. concat.click(
  98. fn=concat_audio,
  99. inputs=slots, # type: ignore
  100. outputs=output_audio,
  101. )
  102. send_to_input = gr.Button("Send to input")
  103. send_to_input.click(
  104. fn=lambda x: gr.Audio(x),
  105. inputs=output_audio,
  106. outputs=input_audio,
  107. )
  108. return input_audio
  109. def simple_remixer_tab():
  110. with gr.Tab("Simple Remixer", id="simple_remixer"):
  111. return simple_remixer_ui()