123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- import gradio as gr
- from typing import List, TypedDict, Literal
- import numpy as np
- from importlib.metadata import version
- from tts_webui.utils.only_overlap import only_overlap
- from tts_webui.history_tab.save_to_favorites import save_to_favorites
- from tts_webui.utils.get_path_from_root import get_path_from_root
- from tts_webui.utils.list_dir_models import model_select_ui, unload_model_button
- from tts_webui.utils.randomize_seed import randomize_seed_ui
- from tts_webui.utils.manage_model_state import manage_model_state, unload_model
- from tts_webui.decorators.gradio_dict_decorator import dictionarize
- from tts_webui.decorators.decorator_apply_torch_seed import decorator_apply_torch_seed
- from tts_webui.decorators.decorator_log_generation import decorator_log_generation
- from tts_webui.decorators.decorator_save_metadata import decorator_save_metadata
- from tts_webui.decorators.decorator_save_wav import decorator_save_wav
- from tts_webui.decorators.decorator_add_base_filename import decorator_add_base_filename
- from tts_webui.decorators.decorator_add_date import decorator_add_date
- from tts_webui.decorators.decorator_add_model_type import decorator_add_model_type
- from tts_webui.decorators.log_function_time import log_function_time
- from tts_webui.decorators.decorator_save_musicgen_npz import decorator_save_musicgen_npz
- from tts_webui.extensions_loader.decorator_extensions import (
- decorator_extension_outer,
- decorator_extension_inner,
- )
- AUDIOCRAFT_VERSION = version("audiocraft")
- class MagnetGenerationParams(TypedDict):
- use_sampling: bool
- top_k: int
- top_p: float
- temperature: float
- max_cfg_coef: float
- min_cfg_coef: float
- decoding_steps: List[int]
- span_arrangement: Literal["nonoverlap", "stride1"]
- @manage_model_state("magnet")
- def get_model(model):
- from audiocraft.models.magnet import MAGNeT
- return MAGNeT.get_pretrained(model)
- @decorator_extension_outer
- @decorator_apply_torch_seed
- @decorator_save_musicgen_npz
- @decorator_save_metadata
- @decorator_save_wav
- @decorator_add_model_type("magnet")
- @decorator_add_base_filename
- @decorator_add_date
- @decorator_log_generation
- @decorator_extension_inner
- @log_function_time
- def generate(
- model_name: str,
- text: str,
- decoding_steps_1: int,
- decoding_steps_2: int,
- decoding_steps_3: int,
- decoding_steps_4: int,
- **kwargs,
- ):
- model_inst = get_model(model_name)
- model_inst.set_generation_params(
- **only_overlap(
- {
- **kwargs,
- "decoding_steps": [
- decoding_steps_1,
- decoding_steps_2,
- decoding_steps_3,
- decoding_steps_4,
- ],
- },
- MagnetGenerationParams,
- )
- )
- output, tokens = model_inst.generate(
- descriptions=[text],
- progress=True,
- return_tokens=True,
- )
- audio_array = output.detach().cpu().numpy().squeeze()
- stereo = audio_array.shape[0] == 2
- if stereo:
- audio_array = np.transpose(audio_array)
- return {"audio_out": (model_inst.sample_rate, audio_array), "tokens": tokens}
- def magnet_tab():
- with gr.Tab("Magnet"):
- magnet_ui()
- def magnet_ui():
- gr.Markdown(f"""Audiocraft version: {AUDIOCRAFT_VERSION}""")
- with gr.Row(equal_height=False):
- with gr.Column():
- text = gr.Textbox(label="Prompt", lines=3, placeholder="Enter text here...")
- model_name = model_select_ui(
- [
- ("Magnet, 10s, Small", "facebook/magnet-small-10secs"),
- ("Magnet, 10s, Medium", "facebook/magnet-medium-10secs"),
- ("Magnet, 30s, Small", "facebook/magnet-small-30secs"),
- ("Magnet, 30s, Medium", "facebook/magnet-medium-30secs"),
- ("Audio, Magnet, Small", "facebook/audio-magnet-small"),
- ("Audio, Magnet, Medium", "facebook/audio-magnet-medium"),
- ],
- "magnet",
- )
- unload_model_button("magnet")
- submit = gr.Button("Generate", variant="primary")
- with gr.Column():
- with gr.Row():
- top_k = gr.Number(label="Top-k", value=0)
- top_p = gr.Slider(
- minimum=0.0,
- maximum=1.5,
- value=0.95,
- label="Top-p",
- step=0.05,
- )
- temperature = gr.Slider(
- minimum=0.0,
- maximum=10,
- value=1.0,
- label="Temperature",
- step=0.05,
- )
- with gr.Row():
- min_cfg_coef = gr.Slider(
- label="Min CFG coefficient",
- value=1.0,
- minimum=0,
- step=0.5,
- )
- max_cfg_coef = gr.Slider(
- label="Max CFG coefficient",
- value=20.0,
- minimum=0,
- step=0.5,
- )
- with gr.Row():
- gr.Markdown("Decoding Steps:")
- decoding_steps_1 = gr.Slider(label="Stage 1", value=80)
- decoding_steps_2 = gr.Slider(label="Stage 2", value=40)
- decoding_steps_3 = gr.Slider(label="Stage 3", value=40)
- decoding_steps_4 = gr.Slider(label="Stage 4", value=20)
- with gr.Row():
- span_arrangement = gr.Radio(
- ["nonoverlap", "stride1"],
- label="Span Scoring",
- value="nonoverlap",
- )
- use_sampling = gr.Checkbox(label="Use Sampling", value=True)
- seed, randomize_seed_callback = randomize_seed_ui()
- with gr.Column():
- audio_out = gr.Audio(label="Generated Music", type="numpy")
- with gr.Row():
- folder_root = gr.Textbox(visible=False)
- save_button = gr.Button("Save to favorites", visible=True)
- save_button.click(
- fn=save_to_favorites,
- inputs=[folder_root],
- outputs=[save_button],
- )
- inputs_dict = {
- model_name: "model_name",
- text: "text",
- seed: "seed",
- use_sampling: "use_sampling",
- top_k: "top_k",
- top_p: "top_p",
- temperature: "temperature",
- max_cfg_coef: "max_cfg_coef",
- min_cfg_coef: "min_cfg_coef",
- decoding_steps_1: "decoding_steps_1",
- decoding_steps_2: "decoding_steps_2",
- decoding_steps_3: "decoding_steps_3",
- decoding_steps_4: "decoding_steps_4",
- span_arrangement: "span_arrangement",
- }
- outputs_dict = {
- "audio_out": audio_out,
- "metadata": gr.JSON(label="Metadata", visible=False),
- "folder_root": folder_root,
- }
- submit.click(
- **randomize_seed_callback,
- ).then(
- **dictionarize(
- fn=generate,
- inputs=inputs_dict,
- outputs=outputs_dict,
- ),
- api_name="magnet",
- )
- if __name__ == "__main__":
- with gr.Blocks() as demo:
- magnet_tab()
- demo.launch(
- server_port=7770,
- )
|