tortoise_tab.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import gradio as gr
  2. from tts_webui.tortoise.gen_tortoise import (
  3. generate_tortoise_long,
  4. get_voice_list,
  5. TORTOISE_VOICE_DIR_ABS,
  6. )
  7. from tts_webui.tortoise.TortoiseParameters import (
  8. TortoiseParameterComponents,
  9. TortoiseParameters,
  10. )
  11. from tts_webui.tortoise.autoregressive_params import autoregressive_params
  12. from tts_webui.tortoise.diffusion_params import diffusion_params
  13. from tts_webui.tortoise.presets import presets
  14. from tts_webui.tortoise.gr_reload_button import gr_open_button_simple, gr_reload_button
  15. from tts_webui.tortoise.tortoise_model_settings_ui import tortoise_model_settings_ui
  16. from tts_webui.utils.randomize_seed import randomize_seed_ui
  17. MAX_OUTPUTS = 9
  18. def tortoise_tab():
  19. with gr.Tab("Tortoise TTS"):
  20. tortoise_ui()
  21. def tortoise_ui():
  22. with gr.Row():
  23. with gr.Column():
  24. model = tortoise_model_settings_ui()
  25. with gr.Column():
  26. gr.Markdown("Voice")
  27. with gr.Row():
  28. voice = gr.Dropdown(
  29. choices=["Press refresh to load the list"],
  30. value="Press refresh to load the list",
  31. show_label=False,
  32. container=False,
  33. allow_custom_value=True,
  34. )
  35. gr_open_button_simple(
  36. TORTOISE_VOICE_DIR_ABS, api_name="tortoise_open_voices"
  37. )
  38. gr_reload_button().click(
  39. fn=lambda: gr.Dropdown(choices=get_voice_list()),
  40. outputs=[voice],
  41. api_name="tortoise_refresh_voices",
  42. )
  43. with gr.Column():
  44. gr.Markdown("Preset")
  45. preset = gr.Dropdown(
  46. show_label=False,
  47. choices=[
  48. "ultra_fast",
  49. "fast",
  50. "standard",
  51. "high_quality",
  52. ],
  53. value="ultra_fast",
  54. container=False,
  55. )
  56. (
  57. num_autoregressive_samples,
  58. temperature,
  59. length_penalty,
  60. repetition_penalty,
  61. top_p,
  62. max_mel_tokens,
  63. ) = autoregressive_params()
  64. with gr.Column():
  65. cvvp_amount = gr.Slider(
  66. label="CVVP Amount (Deprecated, always 0)",
  67. value=0.0,
  68. minimum=0.0,
  69. maximum=1.0,
  70. step=0.1,
  71. interactive=False,
  72. )
  73. seed, randomize_seed_callback = randomize_seed_ui()
  74. split_prompt = gr.Checkbox(label="Split prompt by lines", value=False)
  75. (
  76. diffusion_iterations,
  77. cond_free,
  78. cond_free_k,
  79. diffusion_temperature,
  80. ) = diffusion_params()
  81. name = gr.Textbox(label="Generation Name", placeholder="Enter name here...")
  82. preset.change(
  83. fn=lambda x: [
  84. gr.Slider(value=presets[x]["num_autoregressive_samples"]),
  85. gr.Slider(value=presets[x]["diffusion_iterations"]),
  86. gr.Checkbox(
  87. value=presets[x]["cond_free"] if "cond_free" in presets[x] else True
  88. ),
  89. ],
  90. inputs=[preset],
  91. outputs=[num_autoregressive_samples, diffusion_iterations, cond_free],
  92. )
  93. text = gr.Textbox(label="Prompt", lines=3, placeholder="Enter text here...")
  94. inputs = list(
  95. TortoiseParameterComponents(
  96. text=text,
  97. voice=voice,
  98. preset=preset,
  99. seed=seed,
  100. cvvp_amount=cvvp_amount,
  101. split_prompt=split_prompt,
  102. num_autoregressive_samples=num_autoregressive_samples,
  103. diffusion_iterations=diffusion_iterations,
  104. temperature=temperature,
  105. length_penalty=length_penalty,
  106. repetition_penalty=repetition_penalty,
  107. top_p=top_p,
  108. max_mel_tokens=max_mel_tokens,
  109. cond_free=cond_free,
  110. cond_free_k=cond_free_k,
  111. diffusion_temperature=diffusion_temperature,
  112. model=model,
  113. name=name,
  114. )
  115. )
  116. with gr.Column():
  117. audio = gr.Audio(type="filepath", label="Generated audio")
  118. folder_root = gr.Textbox(visible=False)
  119. metadata = gr.JSON(visible=False)
  120. with gr.Row():
  121. from tts_webui.history_tab.save_to_favorites import save_to_favorites
  122. gr.Button("Save to favorites").click(
  123. fn=save_to_favorites,
  124. inputs=[folder_root],
  125. )
  126. def generate_button(count):
  127. def gen(*args):
  128. yield from generate_tortoise_long(
  129. count,
  130. TortoiseParameters.from_list(list(args)),
  131. )
  132. return (
  133. gr.Button(
  134. value=f"Generate {count if count > 1 else ''}",
  135. variant="primary" if count == 1 else "secondary",
  136. )
  137. .click(**randomize_seed_callback)
  138. .then(
  139. fn=gen,
  140. inputs=inputs,
  141. outputs=[audio, folder_root, metadata],
  142. api_name=f"generate_tortoise_{count}",
  143. )
  144. )
  145. with gr.Row():
  146. for i in range(MAX_OUTPUTS):
  147. generate_button(MAX_OUTPUTS - i)
  148. if __name__ == "__main__":
  149. if "demo" in locals():
  150. demo.close() # type: ignore
  151. from tts_webui.css.css import full_css
  152. with gr.Blocks(css=full_css) as demo:
  153. tortoise_tab()
  154. demo.launch(
  155. server_port=7770,
  156. )