# musicgen_tab.py
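"""Gradio tab for MusicGen / AudioGen text-to-audio generation.

Builds the "MusicGen + AudioGen" UI, wires it to the decorated `generate`
function through `dictionarize`, and saves each result (wav, npz tokens,
JSON metadata) via the tts_webui decorator stack.
"""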
import torch
import gradio as gr
from einops import rearrange
from typing import Optional, Tuple
import numpy as np
from importlib.metadata import version

from tts_webui.decorators.gradio_dict_decorator import dictionarize
from tts_webui.musicgen.audio_array_to_sha256 import audio_array_to_sha256
from tts_webui.utils.randomize_seed import randomize_seed_ui
from tts_webui.history_tab.save_to_favorites import save_to_favorites
from tts_webui.utils.list_dir_models import model_select_ui, unload_model_button
from tts_webui.utils.manage_model_state import manage_model_state
from tts_webui.decorators.decorator_apply_torch_seed import decorator_apply_torch_seed
from tts_webui.decorators.decorator_log_generation import decorator_log_generation
from tts_webui.decorators.decorator_save_wav import decorator_save_wav
from tts_webui.decorators.decorator_add_base_filename import (
    decorator_add_base_filename,
)
from tts_webui.decorators.decorator_add_date import decorator_add_date
from tts_webui.decorators.decorator_add_model_type import decorator_add_model_type
from tts_webui.decorators.log_function_time import log_function_time
from tts_webui.decorators.decorator_save_musicgen_npz import (
    decorator_save_musicgen_npz,
)
from tts_webui.extensions_loader.decorator_extensions import (
    decorator_extension_outer,
    decorator_extension_inner,
)
from tts_webui.utils.save_json_result import save_json_result

AUDIOCRAFT_VERSION = version("audiocraft")


def melody_to_sha256(melody: Optional[Tuple[int, np.ndarray]]) -> Optional[str]:
    """Hash the optional melody conditioning audio so it can be recorded in metadata."""
    if melody is None:
        return None
    _, audio_array = melody
    return audio_array_to_sha256(audio_array)

def _decorator_musicgen_save_metadata(fn):
    """Attach a metadata dict (inputs, hashes, versions) to the result and save it as JSON."""

    def wrapper(*args, **kwargs):
        result_dict = fn(*args, **kwargs)
        audio_array = result_dict["audio_out"][1]
        result_dict["metadata"] = {
            "_version": "0.0.1",
            "_hash_version": "0.0.3",
            "_audiocraft_version": AUDIOCRAFT_VERSION,
            **kwargs,
            "outputs": None,
            "models": {},
            "hash": audio_array_to_sha256(audio_array),
            "date": str(result_dict["date"]),
            # Later keys override **kwargs: store a hash of the melody instead of the raw audio tuple.
            "melody": melody_to_sha256(kwargs.get("melody", None)),
        }
        save_json_result(result_dict, result_dict["metadata"])
        return result_dict

    return wrapper

@manage_model_state("musicgen_audiogen")
def load_model(version):
    from audiocraft.models.musicgen import MusicGen
    from audiocraft.models.audiogen import AudioGen

    if version == "facebook/audiogen-medium":
        return AudioGen.get_pretrained(version)
    return MusicGen.get_pretrained(version)
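
# Decorator stack for `generate`, listed outermost first: extension hooks and
# torch seeding wrap the output-saving decorators (npz tokens, JSON metadata,
# wav), which in turn wrap the decorators that stamp model type, base filename,
# and date on the result and log/time the generation.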
@decorator_extension_outer
@decorator_apply_torch_seed
@decorator_save_musicgen_npz
@_decorator_musicgen_save_metadata
@decorator_save_wav
@decorator_add_model_type("musicgen")
@decorator_add_base_filename
@decorator_add_date
@decorator_log_generation
@decorator_extension_inner
@log_function_time
def generate(
    text,
    melody,
    model_name,
    duration,
    topk,
    topp,
    temperature,
    cfg_coef,
    use_multi_band_diffusion,
    **kwargs,
):
    model_inst = load_model(model_name)
    model_inst.set_generation_params(
        use_sampling=True,
        top_k=topk,
        top_p=topp,
        temperature=temperature,
        cfg_coef=cfg_coef,
        duration=duration,
    )

    if "melody" in model_name and melody is not None:
        # gr.Audio(type="numpy") yields (sample_rate, np.ndarray); convert it to
        # a [batch, channels, time] float tensor on the model's device.
        sr, melody = (
            melody[0],
            torch.from_numpy(melody[1]).to(model_inst.device).float().t().unsqueeze(0),
        )
        print(melody.shape)
        if melody.dim() == 2:
            melody = melody[None]
        # Trim the melody to the model's maximum conditioning length.
        melody = melody[..., : int(sr * model_inst.lm.cfg.dataset.segment_duration)]  # type: ignore
        output, tokens = model_inst.generate_with_chroma(
            descriptions=[text],
            melody_wavs=melody,
            melody_sample_rate=sr,
            progress=True,
            return_tokens=True,
        )
    elif model_name == "facebook/audiogen-medium":
        # AudioGen generation does not return tokens.
        output = model_inst.generate(
            descriptions=[text],
            progress=True,
        )
        tokens = None
    else:
        output, tokens = model_inst.generate(
            descriptions=[text],
            progress=True,
            return_tokens=True,
        )

    if use_multi_band_diffusion:
        if model_name != "facebook/audiogen-medium":
            from audiocraft.models.multibanddiffusion import MultiBandDiffusion
            from audiocraft.models.encodec import InterleaveStereoCompressionModel

            mbd = MultiBandDiffusion.get_mbd_musicgen()
            if isinstance(
                model_inst.compression_model, InterleaveStereoCompressionModel
            ):
                # Stereo models interleave left/right codebooks; split them so MBD
                # decodes each channel, then re-pair the channels afterwards.
                left, right = model_inst.compression_model.get_left_right_codes(tokens)
                tokens = torch.cat([left, right])
            outputs_diffusion = mbd.tokens_to_wav(tokens)
            if isinstance(
                model_inst.compression_model, InterleaveStereoCompressionModel
            ):
                assert outputs_diffusion.shape[1] == 1  # output is mono
                outputs_diffusion = rearrange(
                    outputs_diffusion, "(s b) c t -> b (s c) t", s=2
                )
            output = outputs_diffusion.detach().cpu().numpy().squeeze()
        else:
            print("NOTICE: Multi-band diffusion is not supported for AudioGen")
            output = output.detach().cpu().numpy().squeeze()
    else:
        output = output.detach().cpu().numpy().squeeze()

    audio_array = output
    # Gradio expects stereo audio as (samples, channels), so transpose (2, T) arrays.
    if audio_array.shape[0] == 2:
        audio_array = np.transpose(audio_array)

    return {"audio_out": (model_inst.sample_rate, audio_array), "tokens": tokens}
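
# A minimal direct-call sketch (hypothetical, for manual testing outside the UI);
# it assumes the extra `seed` kwarg is consumed by decorator_apply_torch_seed:
#
#   result = generate(
#       text="upbeat acoustic guitar loop",
#       melody=None,
#       model_name="facebook/musicgen-small",
#       duration=5,
#       topk=250,
#       topp=0.0,
#       temperature=1.0,
#       cfg_coef=3.0,
#       use_multi_band_diffusion=False,
#       seed=0,
#   )
#   sample_rate, audio_array = result["audio_out"]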

def musicgen_tab():
    with gr.Tab("MusicGen + AudioGen"):
        musicgen_ui()


def musicgen_ui():
    gr.Markdown(f"""Audiocraft version: {AUDIOCRAFT_VERSION}""")
    with gr.Row(equal_height=False):
        with gr.Column():
            text = gr.Textbox(
                label="Prompt", lines=3, placeholder="Enter text here..."
            )
            model_name = model_select_ui(
                [
                    ("Melody", "facebook/musicgen-melody"),
                    ("Medium", "facebook/musicgen-medium"),
                    ("Small", "facebook/musicgen-small"),
                    ("Large", "facebook/musicgen-large"),
                    ("Audiogen", "facebook/audiogen-medium"),
                    ("Melody Large", "facebook/musicgen-melody-large"),
                    ("Stereo Small", "facebook/musicgen-stereo-small"),
                    ("Stereo Medium", "facebook/musicgen-stereo-medium"),
                    ("Stereo Melody", "facebook/musicgen-stereo-melody"),
                    ("Stereo Large", "facebook/musicgen-stereo-large"),
                    ("Stereo Melody Large", "facebook/musicgen-stereo-melody-large"),
                ],
                "musicgen_audiogen",
            )
            melody = gr.Audio(
                sources="upload", type="numpy", label="Melody (optional)"
            )
            submit = gr.Button("Generate", variant="primary")
        with gr.Column():
            duration = gr.Slider(minimum=1, maximum=360, value=10, label="Duration")
            with gr.Row():
                topk = gr.Number(label="Top-k", value=250, interactive=True)
                topp = gr.Slider(
                    minimum=0.0, maximum=1.5, value=0.0, label="Top-p", step=0.05
                )
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.5, value=1.0, label="Temperature", step=0.05
                )
                cfg_coef = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    value=3.0,
                    label="Classifier Free Guidance",
                    step=0.1,
                )
            use_multi_band_diffusion = gr.Checkbox(
                label="Use Multi-Band Diffusion (High VRAM Usage)",
                value=False,
            )
            seed, randomize_seed_callback = randomize_seed_ui()
            unload_model_button("musicgen_audiogen")
        with gr.Column():
            audio_out = gr.Audio(label="Generated Music", type="numpy")
            with gr.Row():
                folder_root = gr.Textbox(visible=False)
                save_button = gr.Button("Save to favorites", visible=True)
                melody_button = gr.Button("Use as melody", visible=True)
            save_button.click(
                fn=save_to_favorites,
                inputs=[folder_root],
                outputs=[save_button],
            )
            melody_button.click(
                fn=lambda melody_in: melody_in,
                inputs=[audio_out],
                outputs=[melody],
            )

    # Map UI components to generate()'s keyword arguments, and result keys back
    # to output components, for the dictionarize-based click handler below.
    input_dict = {
        text: "text",
        melody: "melody",
        model_name: "model_name",
        duration: "duration",
        topk: "topk",
        topp: "topp",
        temperature: "temperature",
        cfg_coef: "cfg_coef",
        seed: "seed",
        use_multi_band_diffusion: "use_multi_band_diffusion",
    }

    output_dict = {
        "audio_out": audio_out,
        "metadata": gr.JSON(visible=False),
        "folder_root": folder_root,
    }

    submit.click(
        **randomize_seed_callback,
    ).then(
        **dictionarize(
            fn=generate,
            inputs=input_dict,
            outputs=output_dict,
        ),
        api_name="musicgen",
    )

if __name__ == "__main__":
    if "demo" in locals():
        demo.close()  # type: ignore
    with gr.Blocks() as demo:
        musicgen_tab()

    demo.launch(
        server_port=7770,
    )