- import os
- import json
- import numpy as np
- import torch
- import gradio as gr
- import scipy.io.wavfile as wavfile
- from huggingface_hub import hf_hub_download
- from tts_webui.history_tab.open_folder import open_folder
- from tts_webui.utils.get_path_from_root import get_path_from_root
- from tts_webui.utils.torch_clear_memory import torch_clear_memory
- from tts_webui.utils.prompt_to_title import prompt_to_title
- from tts_webui.utils.date import get_date_string
- from tts_webui.tortoise.gr_reload_button import gr_open_button_simple, gr_reload_button
- LOCAL_DIR_BASE = os.path.join("data", "models", "stable-audio")
- LOCAL_DIR_BASE_ABSOLUTE = get_path_from_root(*LOCAL_DIR_BASE.split(os.path.sep))
- OUTPUT_DIR = os.path.join("outputs-rvc", "Stable Audio")
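- # Lazy wrapper around stable_audio_tools' generate_cond so that importing this
- # module does not pull in stable_audio_tools until a generation is requested.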
- def generate_cond_lazy(
- prompt,
- negative_prompt=None,
- seconds_start=0,
- seconds_total=30,
- cfg_scale=6.0,
- steps=250,
- preview_every=None,
- seed=-1,
- sampler_type="dpmpp-3m-sde",
- sigma_min=0.03,
- sigma_max=1000,
- cfg_rescale=0.0,
- use_init=False,
- init_audio=None,
- init_noise_level=1.0,
- mask_cropfrom=None,
- mask_pastefrom=None,
- mask_pasteto=None,
- mask_maskstart=None,
- mask_maskend=None,
- mask_softnessL=None,
- mask_softnessR=None,
- mask_marination=None,
- batch_size=1,
- ):
- from stable_audio_tools.interface.gradio import generate_cond, model
- if model is None:
- gr.Error("Model not loaded")
- raise Exception("Model not loaded")
- return generate_cond(
- prompt=prompt,
- negative_prompt=negative_prompt,
- seconds_start=seconds_start,
- seconds_total=seconds_total,
- cfg_scale=cfg_scale,
- steps=steps,
- preview_every=preview_every,
- seed=seed,
- sampler_type=sampler_type,
- sigma_min=sigma_min,
- sigma_max=sigma_max,
- cfg_rescale=cfg_rescale,
- use_init=use_init,
- init_audio=init_audio,
- init_noise_level=init_noise_level,
- mask_cropfrom=mask_cropfrom,
- mask_pastefrom=mask_pastefrom,
- mask_pasteto=mask_pasteto,
- mask_maskstart=mask_maskstart,
- mask_maskend=mask_maskend,
- mask_softnessL=mask_softnessL,
- mask_softnessR=mask_softnessR,
- mask_marination=mask_marination,
- batch_size=batch_size,
- )
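- # Local path helpers: each repo is stored under LOCAL_DIR_BASE with "/" in the
- # repo id replaced by "__" (slashes are not valid in folder names).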
- def get_local_dir(name):
- return os.path.join(LOCAL_DIR_BASE, name.replace("/", "__"))
- def get_config_path(name):
- return os.path.join(get_local_dir(name), "model_config.json")
- def get_ckpt_path(name):
- # prefer model.safetensors; fall back to model.ckpt
- safetensors_path = os.path.join(get_local_dir(name), "model.safetensors")
- if os.path.exists(safetensors_path):
- return safetensors_path
- ckpt_path = os.path.join(get_local_dir(name), "model.ckpt")
- if os.path.exists(ckpt_path):
- return ckpt_path
- raise Exception(f"Neither model.safetensors nor model.ckpt exists for {name}")
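- # Fetch model_config.json and the weights (model.safetensors preferred, with a
- # model.ckpt fallback) from the HuggingFace Hub into the local models folder.
- # Example (illustrative): download_pretrained_model("stabilityai/stable-audio-open-1.0", token="")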
- def download_pretrained_model(name: str, token: str):
- local_dir = get_local_dir(name)
- model_config_path = hf_hub_download(
- name,
- filename="model_config.json",
- repo_type="model",
- local_dir=local_dir,
- local_dir_use_symlinks=False,
- token=token,
- )
- # Prefer model.safetensors; fall back to model.ckpt if it is unavailable
- try:
- print(f"Downloading {name} model.safetensors")
- ckpt_path = hf_hub_download(
- name,
- filename="model.safetensors",
- repo_type="model",
- local_dir=local_dir,
- local_dir_use_symlinks=False,
- token=token,
- )
- except Exception:
- print(f"Downloading {name} model.ckpt")
- ckpt_path = hf_hub_download(
- name,
- filename="model.ckpt",
- repo_type="model",
- local_dir=local_dir,
- local_dir_use_symlinks=False,
- token=token,
- )
- # the click handler wires a single Textbox output, so return just one value
- return name
- def get_model_list():
- try:
- return [
- x
- for x in os.listdir(LOCAL_DIR_BASE)
- if os.path.isdir(os.path.join(LOCAL_DIR_BASE, x))
- ]
- except FileNotFoundError as e:
- print(e)
- return []
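- # Read the model_config.json stored next to the checkpoint; surface a visible
- # error in the UI if it is missing.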
- def load_model_config(model_name):
- path = get_config_path(model_name)
- try:
- with open(path) as f:
- return json.load(f)
- except Exception as e:
- print(e)
- message = (
- f"Model config not found at {path}. Please ensure model_config.json exists."
- )
- raise gr.Error(message)
- def stable_audio_ui():
- default_model_config_path = os.path.join(LOCAL_DIR_BASE, "diffusion_cond.json")
- with open(default_model_config_path) as f:
- model_config = json.load(f)
- pretransform_ckpt_path = None
- pretrained_name = None
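- # Load the checkpoint through stable_audio_tools, which keeps the loaded model
- # in a module-level global that generate_cond_lazy reads back at generation time.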
- def load_model_helper(model_name, model_half):
- if model_name is None:
- return model_name
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- from stable_audio_tools.interface.gradio import load_model
- _, model_config_new = load_model(
- model_config=load_model_config(model_name),
- model_ckpt_path=get_ckpt_path(model_name),
- pretrained_name=None,
- pretransform_ckpt_path=pretransform_ckpt_path,
- model_half=model_half,
- device=device, # type: ignore
- )
- model_type = model_config_new["model_type"] # type: ignore
- if model_type != "diffusion_cond":
- gr.Error("Only diffusion_cond models are supported")
- raise Exception("Only diffusion_cond models are supported")
- # if model_type == "diffusion_cond":
- # ui = create_txt2audio_ui(model_config)
- # elif model_type == "diffusion_uncond":
- # ui = create_diffusion_uncond_ui(model_config)
- # elif model_type == "autoencoder" or model_type == "diffusion_autoencoder":
- # ui = create_autoencoder_ui(model_config)
- # elif model_type == "diffusion_prior":
- # ui = create_diffusion_prior_ui(model_config)
- # elif model_type == "lm":
- # ui = create_lm_ui(model_config)
- return model_name
- def model_select_ui():
- with gr.Row():
- with gr.Column():
- with gr.Row():
- model_select = gr.Dropdown(
- choices=get_model_list(), # type: ignore
- label="Model",
- value=pretrained_name,
- )
- gr_open_button_simple(LOCAL_DIR_BASE, api_name="stable_audio_open_models")
- gr_reload_button().click(
- fn=lambda: gr.Dropdown(choices=get_model_list()),
- outputs=[model_select],
- api_name="stable_audio_refresh_models",
- )
- load_model_button = gr.Button(value="Load model")
- with gr.Column():
- gr.Markdown(
- """
- Stable Audio requires a manual model download.
- Please download a model using the download tab, or place one manually in the `data/models/stable-audio` folder.
- Note: due to a [bug](https://github.com/Stability-AI/stable-audio-tools/issues/80), models loaded in half precision
- fail to generate with "init audio" or during "inpainting".
- """
- )
- half_checkbox = gr.Checkbox(
- label="Use half precision when loading the model",
- value=True,
- )
- load_model_button.click(
- fn=load_model_helper,
- inputs=[model_select, half_checkbox],
- outputs=[model_select],
- )
- model_select_ui()
- with gr.Tabs():
- with gr.Tab("Generation"):
- create_sampling_ui(model_config)
- open_dir_btn = gr.Button("Open outputs folder")
- open_dir_btn.click(
- lambda: open_folder(OUTPUT_DIR),
- api_name="stable_audio_open_output_dir",
- )
- with gr.Tab("Inpainting"):
- create_sampling_ui(model_config, inpainting=True)
- open_dir_btn = gr.Button("Open outputs folder")
- open_dir_btn.click(lambda: open_folder(OUTPUT_DIR))
- with gr.Tab("Model Download"):
- model_download_ui()
- def model_download_ui():
- gr.Markdown("""
- Models can be found on the [HuggingFace model hub](https://huggingface.co/models?search=stable-audio-open-1.0).
- Recommended models:
- - voices: RoyalCities/Vocal_Textures_Main
- - piano: RoyalCities/RC_Infinite_Pianos
- - original: stabilityai/stable-audio-open-1.0
- """)
- pretrained_name_text = gr.Textbox(
- label="HuggingFace repo name, e.g. stabilityai/stable-audio-open-1.0",
- value="",
- )
- token_text = gr.Textbox(
- label="HuggingFace Token (Optional, but needed for some non-public models)",
- placeholder="hf_nFjKuKLJF...",
- value="",
- )
- download_btn = gr.Button("Download")
- download_btn.click(
- download_pretrained_model,
- inputs=[pretrained_name_text, token_text],
- outputs=[pretrained_name_text],
- api_name="model_download",
- )
- gr.Markdown(
- "Models can also be downloaded manually and placed within the directory in a folder, for example `data/models/stable-audio/my_model`"
- )
- open_dir_btn = gr.Button("Open local models dir")
- open_dir_btn.click(
- lambda: open_folder(LOCAL_DIR_BASE_ABSOLUTE),
- api_name="model_open_dir",
- )
- def stable_audio_tab():
- with gr.Tab("Stable Audio"):
- stable_audio_ui()
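- # Persist one generation: write the audio as a WAV file plus a JSON sidecar of
- # the generation parameters (order must match the `inputs` list in create_sampling_ui).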
- def save_result(audio, *generation_args):
- date = get_date_string()
- generation_args = {
- "date": date,
- "version": "0.0.1",
- "prompt": generation_args[0],
- "negative_prompt": generation_args[1],
- "seconds_start_slider": generation_args[2],
- "seconds_total_slider": generation_args[3],
- "cfg_scale_slider": generation_args[4],
- "steps_slider": generation_args[5],
- "preview_every_slider": generation_args[6],
- "seed_textbox": generation_args[7],
- "sampler_type_dropdown": generation_args[8],
- "sigma_min_slider": generation_args[9],
- "sigma_max_slider": generation_args[10],
- "cfg_rescale_slider": generation_args[11],
- "init_audio_checkbox": generation_args[12],
- "init_audio_input": generation_args[13],
- "init_noise_level_slider": generation_args[14],
- }
- print(generation_args)
- prompt = generation_args["prompt"]
- name = f"{date}_{prompt_to_title(prompt)}"
- base_dir = os.path.join(OUTPUT_DIR, name)
- os.makedirs(base_dir, exist_ok=True)
- sr, data = audio
- wavfile.write(os.path.join(base_dir, f"{name}.wav"), sr, data)
- with open(os.path.join(base_dir, f"{name}.json"), "w") as outfile:
- json.dump(
- generation_args,
- outfile,
- indent=2,
- default=lambda o: "<not serializable>",
- )
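- # Default timing values for the sampling UI
- # (1,920,000 samples at 32,000 Hz = 60 seconds).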
- sample_rate = 32000
- sample_size = 1920000
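- # Build the generation form; with inpainting=True it adds the mask controls and
- # registers the endpoints under separate api_names.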
- def create_sampling_ui(model_config, inpainting=False):
- with gr.Row():
- with gr.Column(scale=6):
- text = gr.Textbox(show_label=False, placeholder="Prompt")
- negative_prompt = gr.Textbox(
- show_label=False, placeholder="Negative prompt"
- )
- generate_button = gr.Button("Generate", variant="primary", scale=1)
- model_conditioning_config = model_config["model"].get("conditioning", None)
- has_seconds_start = False
- has_seconds_total = False
- if model_conditioning_config is not None:
- for conditioning_config in model_conditioning_config["configs"]:
- if conditioning_config["id"] == "seconds_start":
- has_seconds_start = True
- if conditioning_config["id"] == "seconds_total":
- has_seconds_total = True
- with gr.Row(equal_height=False):
- with gr.Column():
- with gr.Row(visible=has_seconds_start or has_seconds_total):
- # Timing controls
- seconds_start_slider = gr.Slider(
- minimum=0,
- maximum=512,
- step=1,
- value=0,
- label="Seconds start",
- visible=has_seconds_start,
- )
- seconds_total_slider = gr.Slider(
- minimum=0,
- maximum=512,
- step=1,
- value=sample_size // sample_rate,
- label="Seconds total",
- visible=has_seconds_total,
- )
- with gr.Row():
- # Steps slider
- steps_slider = gr.Slider(
- minimum=1, maximum=500, step=1, value=100, label="Steps"
- )
- # Preview Every slider
- preview_every_slider = gr.Slider(
- minimum=0, maximum=100, step=1, value=0, label="Preview Every"
- )
- # CFG scale
- cfg_scale_slider = gr.Slider(
- minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG scale"
- )
- with gr.Accordion("Sampler params", open=False):
- # Seed
- seed_textbox = gr.Textbox(label="Seed", value="-1")
- CUSTOM_randomize_seed_checkbox = gr.Checkbox(
- label="Randomize seed", value=True
- )
- # Sampler params
- with gr.Row():
- sampler_type_dropdown = gr.Dropdown(
- [
- "dpmpp-2m-sde",
- "dpmpp-3m-sde",
- "k-heun",
- "k-lms",
- "k-dpmpp-2s-ancestral",
- "k-dpm-2",
- "k-dpm-fast",
- ],
- label="Sampler type",
- value="dpmpp-3m-sde",
- )
- sigma_min_slider = gr.Slider(
- minimum=0.0,
- maximum=2.0,
- step=0.01,
- value=0.03,
- label="Sigma min",
- )
- sigma_max_slider = gr.Slider(
- minimum=0.0,
- maximum=1000.0,
- step=0.1,
- value=500,
- label="Sigma max",
- )
- cfg_rescale_slider = gr.Slider(
- minimum=0.0,
- maximum=1,
- step=0.01,
- value=0.0,
- label="CFG rescale amount",
- )
- if inpainting:
- # Inpainting Tab
- with gr.Accordion("Inpainting", open=False):
- sigma_max_slider.maximum = 1000
- init_audio_checkbox = gr.Checkbox(label="Do inpainting")
- init_audio_input = gr.Audio(label="Init audio")
- init_noise_level_slider = gr.Slider(
- minimum=0.1,
- maximum=100.0,
- step=0.1,
- value=80,
- label="Init audio noise level",
- visible=False,
- ) # hide this
- mask_cropfrom_slider = gr.Slider(
- minimum=0.0,
- maximum=100.0,
- step=0.1,
- value=0,
- label="Crop From %",
- )
- mask_pastefrom_slider = gr.Slider(
- minimum=0.0,
- maximum=100.0,
- step=0.1,
- value=0,
- label="Paste From %",
- )
- mask_pasteto_slider = gr.Slider(
- minimum=0.0,
- maximum=100.0,
- step=0.1,
- value=100,
- label="Paste To %",
- )
- mask_maskstart_slider = gr.Slider(
- minimum=0.0,
- maximum=100.0,
- step=0.1,
- value=50,
- label="Mask Start %",
- )
- mask_maskend_slider = gr.Slider(
- minimum=0.0,
- maximum=100.0,
- step=0.1,
- value=100,
- label="Mask End %",
- )
- mask_softnessL_slider = gr.Slider(
- minimum=0.0,
- maximum=100.0,
- step=0.1,
- value=0,
- label="Softmask Left Crossfade Length %",
- )
- mask_softnessR_slider = gr.Slider(
- minimum=0.0,
- maximum=100.0,
- step=0.1,
- value=0,
- label="Softmask Right Crossfade Length %",
- )
- mask_marination_slider = gr.Slider(
- minimum=0.0,
- maximum=1,
- step=0.0001,
- value=0,
- label="Marination level",
- visible=False,
- ) # still working on the usefulness of this
- inputs = [
- text,
- negative_prompt,
- seconds_start_slider,
- seconds_total_slider,
- cfg_scale_slider,
- steps_slider,
- preview_every_slider,
- seed_textbox,
- sampler_type_dropdown,
- sigma_min_slider,
- sigma_max_slider,
- cfg_rescale_slider,
- init_audio_checkbox,
- init_audio_input,
- init_noise_level_slider,
- mask_cropfrom_slider,
- mask_pastefrom_slider,
- mask_pasteto_slider,
- mask_maskstart_slider,
- mask_maskend_slider,
- mask_softnessL_slider,
- mask_softnessR_slider,
- mask_marination_slider,
- ]
- else:
- # Default generation tab
- with gr.Accordion("Init audio", open=False):
- init_audio_checkbox = gr.Checkbox(label="Use init audio")
- init_audio_input = gr.Audio(label="Init audio")
- init_noise_level_slider = gr.Slider(
- minimum=0.1,
- maximum=100.0,
- step=0.01,
- value=0.1,
- label="Init noise level",
- )
- inputs = [
- text,
- negative_prompt,
- seconds_start_slider,
- seconds_total_slider,
- cfg_scale_slider,
- steps_slider,
- preview_every_slider,
- seed_textbox,
- sampler_type_dropdown,
- sigma_min_slider,
- sigma_max_slider,
- cfg_rescale_slider,
- init_audio_checkbox,
- init_audio_input,
- init_noise_level_slider,
- ]
- with gr.Column():
- audio_output = gr.Audio(label="Output audio", interactive=False)
- audio_spectrogram_output = gr.Gallery(
- label="Output spectrogram", show_label=False
- )
- send_to_init_button = gr.Button("Send to init audio", scale=1)
- send_to_init_button.click(
- fn=lambda audio: audio,
- inputs=[audio_output],
- outputs=[init_audio_input],
- )
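- # Resolve the seed right before generation: draw a random 32-bit value when
- # "Randomize seed" is checked, otherwise pass the typed seed through.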
- def randomize_seed(seed, should_randomize):
- if should_randomize:
- return np.random.randint(0, 2**32 - 1, dtype=np.uint32)
- else:
- return int(seed)
- generate_button.click(
- fn=randomize_seed,
- inputs=[seed_textbox, CUSTOM_randomize_seed_checkbox],
- outputs=[seed_textbox],
- ).then(
- fn=generate_cond_lazy,
- inputs=inputs,
- outputs=[audio_output, audio_spectrogram_output],
- api_name="stable_audio_inpaint" if inpainting else "stable_audio_generate",
- ).then(
- fn=save_result,
- inputs=[
- audio_output,
- *inputs,
- ],
- api_name="stable_audio_save_inpaint" if inpainting else "stable_audio_save",
- ).then(
- fn=torch_clear_memory,
- )
- # FEATURE - crop the audio to the actual length specified
- # (note: Gradio passes component *values* into callbacks, not the components)
- # def crop_audio(audio, seconds_start, seconds_total):
- # sr, data = audio
- # data = data[int(seconds_start * sr) : int(seconds_total * sr)]
- # return sr, data
- if __name__ == "__main__":
- exec(
- """
- main()
- with gr.Blocks() as interface:
- stable_audio_ui_tab()
- interface.queue()
- interface.launch(
- debug=True,
- )
- """
- )
- # main()