styletts2_tab.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import gradio as gr
  2. from tts_webui.decorators.gradio_dict_decorator import gradio_dict_decorator
  3. from tts_webui.utils.randomize_seed import randomize_seed_ui
  4. from tts_webui.utils.manage_model_state import manage_model_state
  5. from tts_webui.utils.list_dir_models import unload_model_button
  6. from tts_webui.decorators.decorator_apply_torch_seed import decorator_apply_torch_seed
  7. from tts_webui.decorators.decorator_log_generation import decorator_log_generation
  8. from tts_webui.decorators.decorator_save_metadata import decorator_save_metadata
  9. from tts_webui.decorators.decorator_save_wav import decorator_save_wav
  10. from tts_webui.decorators.decorator_add_base_filename import decorator_add_base_filename
  11. from tts_webui.decorators.decorator_add_date import decorator_add_date
  12. from tts_webui.decorators.decorator_add_model_type import decorator_add_model_type
  13. from tts_webui.decorators.log_function_time import log_function_time
  14. from tts_webui.extensions_loader.decorator_extensions import (
  15. decorator_extension_outer,
  16. decorator_extension_inner,
  17. )
  18. SAMPLE_RATE = 24_000
  19. @manage_model_state("style_tts2")
  20. def get_model(model_name=""):
  21. from styletts2.tts import StyleTTS2
  22. return StyleTTS2(
  23. model_checkpoint_path=None if model_name == "" else model_name,
  24. config_path=None,
  25. )
  26. def preview_phonemization(text):
  27. from nltk.tokenize import word_tokenize
  28. style_tts2_model = get_model("")
  29. text = text.strip()
  30. text = text.replace('"', "")
  31. phonemized_text = style_tts2_model.phoneme_converter.phonemize(text)
  32. ps = word_tokenize(phonemized_text)
  33. phoneme_string = " ".join(ps)
  34. return phoneme_string
  35. @decorator_extension_outer
  36. @decorator_apply_torch_seed
  37. @decorator_save_metadata
  38. @decorator_save_wav
  39. @decorator_add_model_type("style_tts2")
  40. @decorator_add_base_filename
  41. @decorator_add_date
  42. @decorator_log_generation
  43. @decorator_extension_inner
  44. @log_function_time
  45. def generate_audio_styleTTS2(
  46. text,
  47. alpha=0.3,
  48. beta=0.7,
  49. diffusion_steps=5,
  50. embedding_scale=1,
  51. **kwargs,
  52. ):
  53. model = get_model("")
  54. audio_array = model.inference(
  55. text=text,
  56. alpha=alpha,
  57. beta=beta,
  58. diffusion_steps=diffusion_steps,
  59. embedding_scale=embedding_scale,
  60. # target_voice_path=target_voice_path,
  61. # ref_s=None,
  62. # phonemize=True
  63. )
  64. return {"audio_out": (SAMPLE_RATE, audio_array)}
  65. def style_tts2_ui():
  66. gr.Markdown(
  67. """
  68. # StyleTTS2 Demo
  69. To use it, simply enter your text, and click "Generate".
  70. The model will generate audio from the text.
  71. It uses the [StyleTTS2](https://styletts2.github.io/) model via the [Python Package](https://github.com/sidharthrajaram/StyleTTS2).
  72. As a result, the phonemizer is a MIT licensed subsitute.
  73. Parameters:
  74. * text: Input text to turn into speech.
  75. * alpha: Determines timbre of speech, higher means style is more suitable to text than to the target voice.
  76. * beta: Determines prosody of speech, higher means style is more suitable to text than to the target voice.
  77. * diffusion_steps: The more the steps, the more diverse the samples are, with the cost of speed.
  78. * embedding_scale: Higher scale means style is more conditional to the input text and hence more emotional.
  79. """
  80. )
  81. text = gr.Textbox(label="Text", lines=3, placeholder="Enter text here...")
  82. preview_phonemized_text_button = gr.Button("Preview phonemized text")
  83. phonemized_text = gr.Textbox(
  84. label="Phonemized text (what the model will see)", interactive=False
  85. )
  86. preview_phonemized_text_button.click(
  87. fn=preview_phonemization,
  88. inputs=[text],
  89. outputs=[phonemized_text],
  90. api_name="style_tts2_phonemize",
  91. )
  92. with gr.Row():
  93. alpha = gr.Slider(label="Alpha (timbre)", minimum=-0.5, maximum=2.0, value=0.3)
  94. beta = gr.Slider(label="Beta (prosody)", minimum=-1.0, maximum=2.0, value=0.7)
  95. diffusion_steps = gr.Slider(
  96. label="Diffusion Steps (diversity)", minimum=1, maximum=20, value=5, step=1
  97. )
  98. embedding_scale = gr.Slider(
  99. label="Embedding Scale (emotion)", minimum=0.5, maximum=1.5, value=1.0
  100. )
  101. unload_model_button("style_tts2")
  102. with gr.Row():
  103. reset_params_button = gr.Button("Reset params")
  104. reset_params_button.click(
  105. fn=lambda: [
  106. gr.Slider(value=0.3),
  107. gr.Slider(value=0.7),
  108. gr.Slider(value=5),
  109. gr.Slider(value=1.0),
  110. ],
  111. outputs=[
  112. alpha,
  113. beta,
  114. diffusion_steps,
  115. embedding_scale,
  116. ],
  117. )
  118. generate_button = gr.Button("Generate", variant="primary")
  119. audio_out = gr.Audio(label="Generated audio")
  120. seed, randomize_seed_callback = randomize_seed_ui()
  121. input_dict = {
  122. text: "text",
  123. alpha: "alpha",
  124. beta: "beta",
  125. diffusion_steps: "diffusion_steps",
  126. embedding_scale: "embedding_scale",
  127. seed: "seed",
  128. }
  129. output_dict = {
  130. "audio_out": audio_out,
  131. "metadata": gr.JSON(label="Metadata", visible=False),
  132. "folder_root": gr.Textbox(label="Folder root", visible=False),
  133. }
  134. generate_button.click(
  135. **randomize_seed_callback,
  136. ).then(
  137. fn=gradio_dict_decorator(
  138. fn=generate_audio_styleTTS2,
  139. gradio_fn_input_dictionary=input_dict,
  140. outputs=output_dict,
  141. ),
  142. inputs={*input_dict},
  143. outputs=list(output_dict.values()),
  144. api_name="style_tts2_generate",
  145. )
  146. def style_tts2_tab():
  147. with gr.Tab("StyleTTS2"):
  148. style_tts2_ui()
  149. if __name__ == "__main__":
  150. if "demo" in locals():
  151. locals()["demo"].close()
  152. with gr.Blocks() as demo:
  153. style_tts2_tab()
  154. demo.launch()