stable_audio_tab.py 22 KB


  1. import os
  2. import json
  3. import numpy as np
  4. import torch
  5. import gradio as gr
  6. from huggingface_hub import hf_hub_download
  7. from tts_webui.history_tab.open_folder import open_folder
  8. from tts_webui.utils.get_path_from_root import get_path_from_root
  9. from tts_webui.utils.torch_clear_memory import torch_clear_memory
  10. from tts_webui.utils.prompt_to_title import prompt_to_title
  11. from tts_webui.tortoise.gr_reload_button import gr_open_button_simple, gr_reload_button
# Relative folder that holds one sub-folder per downloaded model.
LOCAL_DIR_BASE = os.path.join("data", "models", "stable-audio")
# The same folder, resolved through get_path_from_root.
LOCAL_DIR_BASE_ABSOLUTE = get_path_from_root(*LOCAL_DIR_BASE.split(os.path.sep))
# Relative folder under which save_result creates one sub-folder per generation.
OUTPUT_DIR = os.path.join("outputs-rvc", "Stable Audio")
  15. def generate_cond_lazy(
  16. prompt,
  17. negative_prompt=None,
  18. seconds_start=0,
  19. seconds_total=30,
  20. cfg_scale=6.0,
  21. steps=250,
  22. preview_every=None,
  23. seed=-1,
  24. sampler_type="dpmpp-3m-sde",
  25. sigma_min=0.03,
  26. sigma_max=1000,
  27. cfg_rescale=0.0,
  28. use_init=False,
  29. init_audio=None,
  30. init_noise_level=1.0,
  31. mask_cropfrom=None,
  32. mask_pastefrom=None,
  33. mask_pasteto=None,
  34. mask_maskstart=None,
  35. mask_maskend=None,
  36. mask_softnessL=None,
  37. mask_softnessR=None,
  38. mask_marination=None,
  39. batch_size=1,
  40. ):
  41. from stable_audio_tools.interface.gradio import generate_cond, model
  42. if model is None:
  43. gr.Error("Model not loaded")
  44. raise Exception("Model not loaded")
  45. return generate_cond(
  46. prompt=prompt,
  47. negative_prompt=negative_prompt,
  48. seconds_start=seconds_start,
  49. seconds_total=seconds_total,
  50. cfg_scale=cfg_scale,
  51. steps=steps,
  52. preview_every=preview_every,
  53. seed=seed,
  54. sampler_type=sampler_type,
  55. sigma_min=sigma_min,
  56. sigma_max=sigma_max,
  57. cfg_rescale=cfg_rescale,
  58. use_init=use_init,
  59. init_audio=init_audio,
  60. init_noise_level=init_noise_level,
  61. mask_cropfrom=mask_cropfrom,
  62. mask_pastefrom=mask_pastefrom,
  63. mask_pasteto=mask_pasteto,
  64. mask_maskstart=mask_maskstart,
  65. mask_maskend=mask_maskend,
  66. mask_softnessL=mask_softnessL,
  67. mask_softnessR=mask_softnessR,
  68. mask_marination=mask_marination,
  69. batch_size=batch_size,
  70. )
  71. def get_local_dir(name):
  72. return os.path.join(LOCAL_DIR_BASE, name.replace("/", "__"))
  73. def get_config_path(name):
  74. return os.path.join(get_local_dir(name), "model_config.json")
  75. def get_ckpt_path(name):
  76. # check if model.safetensors exists, if not, check if model.ckpt exists
  77. safetensor_path = os.path.join(get_local_dir(name), "model.safetensors")
  78. if os.path.exists(safetensor_path):
  79. return safetensor_path
  80. else:
  81. chkpt_path = os.path.join(get_local_dir(name), "model.ckpt")
  82. if os.path.exists(chkpt_path):
  83. return chkpt_path
  84. else:
  85. raise Exception(
  86. f"Neither model.safetensors nor model.ckpt exists for {name}"
  87. )
  88. def download_pretrained_model(name: str, token: str):
  89. local_dir = get_local_dir(name)
  90. model_config_path = hf_hub_download(
  91. name,
  92. filename="model_config.json",
  93. repo_type="model",
  94. local_dir=local_dir,
  95. local_dir_use_symlinks=False,
  96. token=token,
  97. )
  98. # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file
  99. try:
  100. print(f"Downloading {name} model.safetensors")
  101. ckpt_path = hf_hub_download(
  102. name,
  103. filename="model.safetensors",
  104. repo_type="model",
  105. local_dir=local_dir,
  106. local_dir_use_symlinks=False,
  107. token=token,
  108. )
  109. except Exception as e:
  110. print(f"Downloading {name} model.ckpt")
  111. ckpt_path = hf_hub_download(
  112. name,
  113. filename="model.ckpt",
  114. repo_type="model",
  115. local_dir=local_dir,
  116. local_dir_use_symlinks=False,
  117. token=token,
  118. )
  119. return model_config_path, ckpt_path
  120. def get_model_list():
  121. try:
  122. return [
  123. x
  124. for x in os.listdir(LOCAL_DIR_BASE)
  125. if os.path.isdir(os.path.join(LOCAL_DIR_BASE, x))
  126. ]
  127. except FileNotFoundError as e:
  128. print(e)
  129. return []
  130. def load_model_config(model_name):
  131. path = get_config_path(model_name)
  132. try:
  133. with open(path) as f:
  134. return json.load(f)
  135. except Exception as e:
  136. print(e)
  137. message = (
  138. f"Model config not found at {path}. Please ensure model_config.json exists."
  139. )
  140. gr.Error(message)
  141. raise Exception(message)
  142. def stable_audio_ui():
  143. default_model_config_path = os.path.join(LOCAL_DIR_BASE, "diffusion_cond.json")
  144. with open(default_model_config_path) as f:
  145. model_config = json.load(f)
  146. pretransform_ckpt_path = None
  147. pretrained_name = None
  148. def load_model_helper(model_name, model_half):
  149. if model_name == None:
  150. return model_name
  151. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  152. from stable_audio_tools.interface.gradio import load_model
  153. _, model_config_new = load_model(
  154. model_config=load_model_config(model_name),
  155. model_ckpt_path=get_ckpt_path(model_name),
  156. pretrained_name=None,
  157. pretransform_ckpt_path=pretransform_ckpt_path,
  158. model_half=model_half,
  159. device=device, # type: ignore
  160. )
  161. model_type = model_config_new["model_type"] # type: ignore
  162. if model_type != "diffusion_cond":
  163. gr.Error("Only diffusion_cond models are supported")
  164. raise Exception("Only diffusion_cond models are supported")
  165. # if model_type == "diffusion_cond":
  166. # ui = create_txt2audio_ui(model_config)
  167. # elif model_type == "diffusion_uncond":
  168. # ui = create_diffusion_uncond_ui(model_config)
  169. # elif model_type == "autoencoder" or model_type == "diffusion_autoencoder":
  170. # ui = create_autoencoder_ui(model_config)
  171. # elif model_type == "diffusion_prior":
  172. # ui = create_diffusion_prior_ui(model_config)
  173. # elif model_type == "lm":
  174. # ui = create_lm_ui(model_config)
  175. return model_name
  176. def model_select_ui():
  177. with gr.Row():
  178. with gr.Column():
  179. with gr.Row():
  180. model_select = gr.Dropdown(
  181. choices=get_model_list(), # type: ignore
  182. label="Model",
  183. value=pretrained_name,
  184. )
  185. gr_open_button_simple(LOCAL_DIR_BASE, api_name="stable_audio_open_models")
  186. gr_reload_button().click(
  187. fn=lambda: gr.Dropdown(choices=get_model_list()),
  188. outputs=[model_select],
  189. api_name="stable_audio_refresh_models",
  190. )
  191. load_model_button = gr.Button(value="Load model")
  192. with gr.Column():
  193. gr.Markdown(
  194. """
  195. Stable Audio requires a manual download of a model.
  196. Please download a model using the download tab or manually place it in the `data/models/stable-audio` folder.
  197. Note: Due to a [bug](https://github.com/Stability-AI/stable-audio-tools/issues/80) when using half precision
  198. the model will fail to generate with "init audio" or during "inpainting".
  199. """
  200. )
  201. half_checkbox = gr.Checkbox(
  202. label="Use half precision when loading the model",
  203. value=True,
  204. )
  205. load_model_button.click(
  206. fn=load_model_helper,
  207. inputs=[model_select, half_checkbox],
  208. outputs=[model_select],
  209. )
  210. model_select_ui()
  211. with gr.Tabs():
  212. with gr.Tab("Generation"):
  213. create_sampling_ui(model_config)
  214. open_dir_btn = gr.Button("Open outputs folder")
  215. open_dir_btn.click(
  216. lambda: open_folder(OUTPUT_DIR),
  217. api_name="stable_audio_open_output_dir",
  218. )
  219. with gr.Tab("Inpainting"):
  220. create_sampling_ui(model_config, inpainting=True)
  221. open_dir_btn = gr.Button("Open outputs folder")
  222. open_dir_btn.click(lambda: open_folder(OUTPUT_DIR))
  223. with gr.Tab("Model Download"):
  224. model_download_ui()
  225. def model_download_ui():
  226. gr.Markdown("""
  227. Models can be found on the [HuggingFace model hub](https://huggingface.co/models?search=stable-audio-open-1.0).
  228. Recommended models:
  229. voices: RoyalCities/Vocal_Textures_Main
  230. piano: RoyalCities/RC_Infinite_Pianos
  231. original: stabilityai/stable-audio-open-1.0
  232. """)
  233. pretrained_name_text = gr.Textbox(
  234. label="HuggingFace repo name, e.g. stabilityai/stable-audio-open-1.0",
  235. value="",
  236. )
  237. token_text = gr.Textbox(
  238. label="HuggingFace Token (Optional, but needed for some non-public models)",
  239. placeholder="hf_nFjKuKLJF...",
  240. value="",
  241. )
  242. download_btn = gr.Button("Download")
  243. download_btn.click(
  244. download_pretrained_model,
  245. inputs=[pretrained_name_text, token_text],
  246. outputs=[pretrained_name_text],
  247. api_name="model_download",
  248. )
  249. gr.Markdown(
  250. "Models can also be downloaded manually and placed within the directory in a folder, for example `data/models/stable-audio/my_model`"
  251. )
  252. open_dir_btn = gr.Button("Open local models dir")
  253. open_dir_btn.click(
  254. lambda: open_folder(LOCAL_DIR_BASE_ABSOLUTE),
  255. api_name="model_open_dir",
  256. )
  257. def stable_audio_tab():
  258. with gr.Tab("Stable Audio"):
  259. stable_audio_ui()
  260. import scipy.io.wavfile as wavfile
  261. from tts_webui.utils.date import get_date_string
  262. def save_result(audio, *generation_args):
  263. date = get_date_string()
  264. generation_args = {
  265. "date": date,
  266. "version": "0.0.1",
  267. "prompt": generation_args[0],
  268. "negative_prompt": generation_args[1],
  269. "seconds_start_slider": generation_args[2],
  270. "seconds_total_slider": generation_args[3],
  271. "cfg_scale_slider": generation_args[4],
  272. "steps_slider": generation_args[5],
  273. "preview_every_slider": generation_args[6],
  274. "seed_textbox": generation_args[7],
  275. "sampler_type_dropdown": generation_args[8],
  276. "sigma_min_slider": generation_args[9],
  277. "sigma_max_slider": generation_args[10],
  278. "cfg_rescale_slider": generation_args[11],
  279. "init_audio_checkbox": generation_args[12],
  280. "init_audio_input": generation_args[13],
  281. "init_noise_level_slider": generation_args[14],
  282. }
  283. print(generation_args)
  284. prompt = generation_args["prompt"]
  285. name = f"{date}_{prompt_to_title(prompt)}"
  286. base_dir = os.path.join(OUTPUT_DIR, name)
  287. os.makedirs(base_dir, exist_ok=True)
  288. sr, data = audio
  289. wavfile.write(os.path.join(base_dir, f"{name}.wav"), sr, data)
  290. with open(os.path.join(base_dir, f"{name}.json"), "w") as outfile:
  291. json.dump(
  292. generation_args,
  293. outfile,
  294. indent=2,
  295. default=lambda o: "<not serializable>",
  296. )
# Module-level defaults read by create_sampling_ui: sample_size // sample_rate
# is the initial value of the "Seconds total" slider.
sample_rate = 32000
sample_size = 1920000
  299. def create_sampling_ui(model_config, inpainting=False):
  300. with gr.Row():
  301. with gr.Column(scale=6):
  302. text = gr.Textbox(show_label=False, placeholder="Prompt")
  303. negative_prompt = gr.Textbox(
  304. show_label=False, placeholder="Negative prompt"
  305. )
  306. generate_button = gr.Button("Generate", variant="primary", scale=1)
  307. model_conditioning_config = model_config["model"].get("conditioning", None)
  308. has_seconds_start = False
  309. has_seconds_total = False
  310. if model_conditioning_config is not None:
  311. for conditioning_config in model_conditioning_config["configs"]:
  312. if conditioning_config["id"] == "seconds_start":
  313. has_seconds_start = True
  314. if conditioning_config["id"] == "seconds_total":
  315. has_seconds_total = True
  316. with gr.Row(equal_height=False):
  317. with gr.Column():
  318. with gr.Row(visible=has_seconds_start or has_seconds_total):
  319. # Timing controls
  320. seconds_start_slider = gr.Slider(
  321. minimum=0,
  322. maximum=512,
  323. step=1,
  324. value=0,
  325. label="Seconds start",
  326. visible=has_seconds_start,
  327. )
  328. seconds_total_slider = gr.Slider(
  329. minimum=0,
  330. maximum=512,
  331. step=1,
  332. value=sample_size // sample_rate,
  333. label="Seconds total",
  334. visible=has_seconds_total,
  335. )
  336. with gr.Row():
  337. # Steps slider
  338. steps_slider = gr.Slider(
  339. minimum=1, maximum=500, step=1, value=100, label="Steps"
  340. )
  341. # Preview Every slider
  342. preview_every_slider = gr.Slider(
  343. minimum=0, maximum=100, step=1, value=0, label="Preview Every"
  344. )
  345. # CFG scale
  346. cfg_scale_slider = gr.Slider(
  347. minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG scale"
  348. )
  349. with gr.Accordion("Sampler params", open=False):
  350. # Seed
  351. seed_textbox = gr.Textbox(label="Seed", value="-1")
  352. CUSTOM_randomize_seed_checkbox = gr.Checkbox(
  353. label="Randomize seed", value=True
  354. )
  355. # Sampler params
  356. with gr.Row():
  357. sampler_type_dropdown = gr.Dropdown(
  358. [
  359. "dpmpp-2m-sde",
  360. "dpmpp-3m-sde",
  361. "k-heun",
  362. "k-lms",
  363. "k-dpmpp-2s-ancestral",
  364. "k-dpm-2",
  365. "k-dpm-fast",
  366. ],
  367. label="Sampler type",
  368. value="dpmpp-3m-sde",
  369. )
  370. sigma_min_slider = gr.Slider(
  371. minimum=0.0,
  372. maximum=2.0,
  373. step=0.01,
  374. value=0.03,
  375. label="Sigma min",
  376. )
  377. sigma_max_slider = gr.Slider(
  378. minimum=0.0,
  379. maximum=1000.0,
  380. step=0.1,
  381. value=500,
  382. label="Sigma max",
  383. )
  384. cfg_rescale_slider = gr.Slider(
  385. minimum=0.0,
  386. maximum=1,
  387. step=0.01,
  388. value=0.0,
  389. label="CFG rescale amount",
  390. )
  391. if inpainting:
  392. # Inpainting Tab
  393. with gr.Accordion("Inpainting", open=False):
  394. sigma_max_slider.maximum = 1000
  395. init_audio_checkbox = gr.Checkbox(label="Do inpainting")
  396. init_audio_input = gr.Audio(label="Init audio")
  397. init_noise_level_slider = gr.Slider(
  398. minimum=0.1,
  399. maximum=100.0,
  400. step=0.1,
  401. value=80,
  402. label="Init audio noise level",
  403. visible=False,
  404. ) # hide this
  405. mask_cropfrom_slider = gr.Slider(
  406. minimum=0.0,
  407. maximum=100.0,
  408. step=0.1,
  409. value=0,
  410. label="Crop From %",
  411. )
  412. mask_pastefrom_slider = gr.Slider(
  413. minimum=0.0,
  414. maximum=100.0,
  415. step=0.1,
  416. value=0,
  417. label="Paste From %",
  418. )
  419. mask_pasteto_slider = gr.Slider(
  420. minimum=0.0,
  421. maximum=100.0,
  422. step=0.1,
  423. value=100,
  424. label="Paste To %",
  425. )
  426. mask_maskstart_slider = gr.Slider(
  427. minimum=0.0,
  428. maximum=100.0,
  429. step=0.1,
  430. value=50,
  431. label="Mask Start %",
  432. )
  433. mask_maskend_slider = gr.Slider(
  434. minimum=0.0,
  435. maximum=100.0,
  436. step=0.1,
  437. value=100,
  438. label="Mask End %",
  439. )
  440. mask_softnessL_slider = gr.Slider(
  441. minimum=0.0,
  442. maximum=100.0,
  443. step=0.1,
  444. value=0,
  445. label="Softmask Left Crossfade Length %",
  446. )
  447. mask_softnessR_slider = gr.Slider(
  448. minimum=0.0,
  449. maximum=100.0,
  450. step=0.1,
  451. value=0,
  452. label="Softmask Right Crossfade Length %",
  453. )
  454. mask_marination_slider = gr.Slider(
  455. minimum=0.0,
  456. maximum=1,
  457. step=0.0001,
  458. value=0,
  459. label="Marination level",
  460. visible=False,
  461. ) # still working on the usefulness of this
  462. inputs = [
  463. text,
  464. negative_prompt,
  465. seconds_start_slider,
  466. seconds_total_slider,
  467. cfg_scale_slider,
  468. steps_slider,
  469. preview_every_slider,
  470. seed_textbox,
  471. sampler_type_dropdown,
  472. sigma_min_slider,
  473. sigma_max_slider,
  474. cfg_rescale_slider,
  475. init_audio_checkbox,
  476. init_audio_input,
  477. init_noise_level_slider,
  478. mask_cropfrom_slider,
  479. mask_pastefrom_slider,
  480. mask_pasteto_slider,
  481. mask_maskstart_slider,
  482. mask_maskend_slider,
  483. mask_softnessL_slider,
  484. mask_softnessR_slider,
  485. mask_marination_slider,
  486. ]
  487. else:
  488. # Default generation tab
  489. with gr.Accordion("Init audio", open=False):
  490. init_audio_checkbox = gr.Checkbox(label="Use init audio")
  491. init_audio_input = gr.Audio(label="Init audio")
  492. init_noise_level_slider = gr.Slider(
  493. minimum=0.1,
  494. maximum=100.0,
  495. step=0.01,
  496. value=0.1,
  497. label="Init noise level",
  498. )
  499. inputs = [
  500. text,
  501. negative_prompt,
  502. seconds_start_slider,
  503. seconds_total_slider,
  504. cfg_scale_slider,
  505. steps_slider,
  506. preview_every_slider,
  507. seed_textbox,
  508. sampler_type_dropdown,
  509. sigma_min_slider,
  510. sigma_max_slider,
  511. cfg_rescale_slider,
  512. init_audio_checkbox,
  513. init_audio_input,
  514. init_noise_level_slider,
  515. ]
  516. with gr.Column():
  517. audio_output = gr.Audio(label="Output audio", interactive=False)
  518. audio_spectrogram_output = gr.Gallery(
  519. label="Output spectrogram", show_label=False
  520. )
  521. send_to_init_button = gr.Button("Send to init audio", scale=1)
  522. send_to_init_button.click(
  523. fn=lambda audio: audio,
  524. inputs=[audio_output],
  525. outputs=[init_audio_input],
  526. )
  527. def randomize_seed(seed, randomize_seed):
  528. if randomize_seed:
  529. return np.random.randint(0, 2**32 - 1, dtype=np.uint32)
  530. else:
  531. return int(seed)
  532. generate_button.click(
  533. fn=randomize_seed,
  534. inputs=[seed_textbox, CUSTOM_randomize_seed_checkbox],
  535. outputs=[seed_textbox],
  536. ).then(
  537. fn=generate_cond_lazy,
  538. inputs=inputs,
  539. outputs=[audio_output, audio_spectrogram_output],
  540. api_name="stable_audio_inpaint" if inpainting else "stable_audio_generate",
  541. ).then(
  542. fn=save_result,
  543. inputs=[
  544. audio_output,
  545. *inputs,
  546. ],
  547. api_name="stable_audio_save_inpaint" if inpainting else "stable_audio_save",
  548. ).then(
  549. fn=torch_clear_memory,
  550. )
  551. # FEATURE - crop the audio to the actual length specified
  552. # def crop_audio(audio, seconds_start_slider, seconds_total_slider):
  553. # sr, data = audio
  554. # seconds_start = seconds_start_slider.value
  555. # seconds_total = seconds_total_slider.value
  556. # data = data[int(seconds_start * sr) : int(seconds_total * sr)]
  557. # return sr, data
  558. if __name__ == "__main__":
  559. exec(
  560. """
  561. main()
  562. with gr.Blocks() as interface:
  563. stable_audio_ui_tab()
  564. interface.queue()
  565. interface.launch(
  566. debug=True,
  567. )
  568. """
  569. )
  570. # main()