magnet_tab.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import gradio as gr
  2. from typing import List, TypedDict, Literal
  3. import numpy as np
  4. from importlib.metadata import version
  5. from tts_webui.utils.only_overlap import only_overlap
  6. from tts_webui.history_tab.save_to_favorites import save_to_favorites
  7. from tts_webui.utils.get_path_from_root import get_path_from_root
  8. from tts_webui.utils.list_dir_models import model_select_ui, unload_model_button
  9. from tts_webui.utils.randomize_seed import randomize_seed_ui
  10. from tts_webui.utils.manage_model_state import manage_model_state, unload_model
  11. from tts_webui.decorators.gradio_dict_decorator import dictionarize
  12. from tts_webui.decorators.decorator_apply_torch_seed import decorator_apply_torch_seed
  13. from tts_webui.decorators.decorator_log_generation import decorator_log_generation
  14. from tts_webui.decorators.decorator_save_metadata import decorator_save_metadata
  15. from tts_webui.decorators.decorator_save_wav import decorator_save_wav
  16. from tts_webui.decorators.decorator_add_base_filename import decorator_add_base_filename
  17. from tts_webui.decorators.decorator_add_date import decorator_add_date
  18. from tts_webui.decorators.decorator_add_model_type import decorator_add_model_type
  19. from tts_webui.decorators.log_function_time import log_function_time
  20. from tts_webui.decorators.decorator_save_musicgen_npz import decorator_save_musicgen_npz
  21. from tts_webui.extensions_loader.decorator_extensions import (
  22. decorator_extension_outer,
  23. decorator_extension_inner,
  24. )
  25. AUDIOCRAFT_VERSION = version("audiocraft")
  26. class MagnetGenerationParams(TypedDict):
  27. use_sampling: bool
  28. top_k: int
  29. top_p: float
  30. temperature: float
  31. max_cfg_coef: float
  32. min_cfg_coef: float
  33. decoding_steps: List[int]
  34. span_arrangement: Literal["nonoverlap", "stride1"]
  35. @manage_model_state("magnet")
  36. def get_model(model):
  37. from audiocraft.models.magnet import MAGNeT
  38. return MAGNeT.get_pretrained(model)
  39. @decorator_extension_outer
  40. @decorator_apply_torch_seed
  41. @decorator_save_musicgen_npz
  42. @decorator_save_metadata
  43. @decorator_save_wav
  44. @decorator_add_model_type("magnet")
  45. @decorator_add_base_filename
  46. @decorator_add_date
  47. @decorator_log_generation
  48. @decorator_extension_inner
  49. @log_function_time
  50. def generate(
  51. model_name: str,
  52. text: str,
  53. decoding_steps_1: int,
  54. decoding_steps_2: int,
  55. decoding_steps_3: int,
  56. decoding_steps_4: int,
  57. **kwargs,
  58. ):
  59. model_inst = get_model(model_name)
  60. model_inst.set_generation_params(
  61. **only_overlap(
  62. {
  63. **kwargs,
  64. "decoding_steps": [
  65. decoding_steps_1,
  66. decoding_steps_2,
  67. decoding_steps_3,
  68. decoding_steps_4,
  69. ],
  70. },
  71. MagnetGenerationParams,
  72. )
  73. )
  74. output, tokens = model_inst.generate(
  75. descriptions=[text],
  76. progress=True,
  77. return_tokens=True,
  78. )
  79. audio_array = output.detach().cpu().numpy().squeeze()
  80. stereo = audio_array.shape[0] == 2
  81. if stereo:
  82. audio_array = np.transpose(audio_array)
  83. return {"audio_out": (model_inst.sample_rate, audio_array), "tokens": tokens}
  84. def magnet_tab():
  85. with gr.Tab("Magnet"):
  86. magnet_ui()
  87. def magnet_ui():
  88. gr.Markdown(f"""Audiocraft version: {AUDIOCRAFT_VERSION}""")
  89. with gr.Row(equal_height=False):
  90. with gr.Column():
  91. text = gr.Textbox(label="Prompt", lines=3, placeholder="Enter text here...")
  92. model_name = model_select_ui(
  93. [
  94. ("Magnet, 10s, Small", "facebook/magnet-small-10secs"),
  95. ("Magnet, 10s, Medium", "facebook/magnet-medium-10secs"),
  96. ("Magnet, 30s, Small", "facebook/magnet-small-30secs"),
  97. ("Magnet, 30s, Medium", "facebook/magnet-medium-30secs"),
  98. ("Audio, Magnet, Small", "facebook/audio-magnet-small"),
  99. ("Audio, Magnet, Medium", "facebook/audio-magnet-medium"),
  100. ],
  101. "magnet",
  102. )
  103. unload_model_button("magnet")
  104. submit = gr.Button("Generate", variant="primary")
  105. with gr.Column():
  106. with gr.Row():
  107. top_k = gr.Number(label="Top-k", value=0)
  108. top_p = gr.Slider(
  109. minimum=0.0,
  110. maximum=1.5,
  111. value=0.95,
  112. label="Top-p",
  113. step=0.05,
  114. )
  115. temperature = gr.Slider(
  116. minimum=0.0,
  117. maximum=10,
  118. value=1.0,
  119. label="Temperature",
  120. step=0.05,
  121. )
  122. with gr.Row():
  123. min_cfg_coef = gr.Slider(
  124. label="Min CFG coefficient",
  125. value=1.0,
  126. minimum=0,
  127. step=0.5,
  128. )
  129. max_cfg_coef = gr.Slider(
  130. label="Max CFG coefficient",
  131. value=20.0,
  132. minimum=0,
  133. step=0.5,
  134. )
  135. with gr.Row():
  136. gr.Markdown("Decoding Steps:")
  137. decoding_steps_1 = gr.Slider(label="Stage 1", value=80)
  138. decoding_steps_2 = gr.Slider(label="Stage 2", value=40)
  139. decoding_steps_3 = gr.Slider(label="Stage 3", value=40)
  140. decoding_steps_4 = gr.Slider(label="Stage 4", value=20)
  141. with gr.Row():
  142. span_arrangement = gr.Radio(
  143. ["nonoverlap", "stride1"],
  144. label="Span Scoring",
  145. value="nonoverlap",
  146. )
  147. use_sampling = gr.Checkbox(label="Use Sampling", value=True)
  148. seed, randomize_seed_callback = randomize_seed_ui()
  149. with gr.Column():
  150. audio_out = gr.Audio(label="Generated Music", type="numpy")
  151. with gr.Row():
  152. folder_root = gr.Textbox(visible=False)
  153. save_button = gr.Button("Save to favorites", visible=True)
  154. save_button.click(
  155. fn=save_to_favorites,
  156. inputs=[folder_root],
  157. outputs=[save_button],
  158. )
  159. inputs_dict = {
  160. model_name: "model_name",
  161. text: "text",
  162. seed: "seed",
  163. use_sampling: "use_sampling",
  164. top_k: "top_k",
  165. top_p: "top_p",
  166. temperature: "temperature",
  167. max_cfg_coef: "max_cfg_coef",
  168. min_cfg_coef: "min_cfg_coef",
  169. decoding_steps_1: "decoding_steps_1",
  170. decoding_steps_2: "decoding_steps_2",
  171. decoding_steps_3: "decoding_steps_3",
  172. decoding_steps_4: "decoding_steps_4",
  173. span_arrangement: "span_arrangement",
  174. }
  175. outputs_dict = {
  176. "audio_out": audio_out,
  177. "metadata": gr.JSON(label="Metadata", visible=False),
  178. "folder_root": folder_root,
  179. }
  180. submit.click(
  181. **randomize_seed_callback,
  182. ).then(
  183. **dictionarize(
  184. fn=generate,
  185. inputs=inputs_dict,
  186. outputs=outputs_dict,
  187. ),
  188. api_name="magnet",
  189. )
  190. if __name__ == "__main__":
  191. with gr.Blocks() as demo:
  192. magnet_tab()
  193. demo.launch(
  194. server_port=7770,
  195. )