# vall_e_x_tab.py — Gradio tab for VALL-E-X text-to-speech generation.
  1. import gradio as gr
  2. from tts_webui.decorators.gradio_dict_decorator import dictionarize
  3. from tts_webui.utils.randomize_seed import randomize_seed_ui
  4. from tts_webui.utils.manage_model_state import manage_model_state
  5. from tts_webui.utils.list_dir_models import unload_model_button
  6. from tts_webui.decorators.decorator_apply_torch_seed import decorator_apply_torch_seed
  7. from tts_webui.decorators.decorator_log_generation import decorator_log_generation
  8. from tts_webui.decorators.decorator_save_metadata import decorator_save_metadata
  9. from tts_webui.decorators.decorator_save_wav import decorator_save_wav
  10. from tts_webui.decorators.decorator_add_base_filename import decorator_add_base_filename
  11. from tts_webui.decorators.decorator_add_date import decorator_add_date
  12. from tts_webui.decorators.decorator_add_model_type import decorator_add_model_type
  13. from tts_webui.decorators.log_function_time import log_function_time
  14. from tts_webui.extensions_loader.decorator_extensions import (
  15. decorator_extension_outer,
  16. decorator_extension_inner,
  17. )
  18. def preprocess_text(text, language="auto"):
  19. from valle_x.utils.generation import (
  20. text_tokenizer,
  21. lang2token,
  22. langid,
  23. )
  24. language = get_lang(language)
  25. text = text.replace("\n", "").strip(" ")
  26. # detect language
  27. if language == "auto":
  28. language = langid.classify(text)[0]
  29. lang_token = lang2token[language]
  30. text = lang_token + text + lang_token
  31. return str(text_tokenizer.tokenize(text=f"_{text}".strip()))
  32. @manage_model_state("valle_x")
  33. def preload_models_if_needed(checkpoints_dir):
  34. from valle_x.utils.generation import preload_models
  35. preload_models(checkpoints_dir=checkpoints_dir)
  36. return "Loaded" # workaround because preload_models returns None
  37. def get_lang(language):
  38. from valle_x.utils.generation import langdropdown2token, token2lang
  39. lang = token2lang[langdropdown2token[language]]
  40. return lang if lang != "mix" else "auto"
  41. @decorator_extension_outer
  42. @decorator_apply_torch_seed
  43. @decorator_save_metadata
  44. @decorator_save_wav
  45. @decorator_add_model_type("valle_x")
  46. @decorator_add_base_filename
  47. @decorator_add_date
  48. @decorator_log_generation
  49. @decorator_extension_inner
  50. @log_function_time
  51. def generate_audio_gradio(text, prompt, language, accent, mode, **kwargs):
  52. from valle_x.utils.generation import (
  53. SAMPLE_RATE,
  54. generate_audio,
  55. generate_audio_from_long_text,
  56. )
  57. preload_models_if_needed("./data/models/vall-e-x/")
  58. lang = get_lang(language)
  59. prompt = prompt if prompt != "" else None
  60. generate_fn = generate_audio if mode == "short" else generate_audio_from_long_text
  61. audio_array = generate_fn(
  62. text=text,
  63. prompt=prompt,
  64. language=lang,
  65. accent=accent,
  66. **({"mode": mode} if mode != "short" else {}),
  67. )
  68. return {"audio_out": (SAMPLE_RATE, audio_array)}
  69. def split_text_into_sentences(text):
  70. from valle_x.utils.sentence_cutter import split_text_into_sentences
  71. return "###\n".join(split_text_into_sentences(text))
def valle_x_ui_generation():
    """Build the VALL-E-X generation UI and wire all of its callbacks."""
    text = gr.Textbox(label="Text", lines=3, placeholder="Enter text here...")
    # Hidden prompt field; reserved for voice-prompt support (always "" for now).
    prompt = gr.Textbox(label="Prompt", visible=False, value="")

    # NOTE(review): original indentation was lost on import — presumably both
    # preview buttons and their output textboxes live inside this accordion;
    # confirm against the upstream layout.
    with gr.Accordion("Analyze text", open=False):
        split_text_into_sentences_button = gr.Button("Preview sentences")
        split_text = gr.Textbox(label="Text after split")
        split_text_into_sentences_button.click(
            fn=split_text_into_sentences,
            inputs=[text],
            outputs=[split_text],
            api_name="vall_e_x_split_text_into_sentences",
        )
        split_text_into_tokens_button = gr.Button("Preview tokens")
        tokens = gr.Textbox(label="Tokens")

    gr.Markdown(
        """
    For longer audio generation, two extension modes are available:
    - (Default) short: This will only generate as long as the model's context length.
    - fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence.
    - sliding-window: This mode will use the last sentence as the prompt for the next sentence, but has some concern on speaker maintenance.
    """
    )

    with gr.Row():
        language = gr.Radio(
            ["English", "中文", "日本語", "Mix"],
            label="Language",
            value="Mix",
        )
        accent = gr.Radio(
            ["English", "中文", "日本語", "no-accent"],
            label="Accent",
            value="no-accent",
        )
        mode = gr.Radio(
            ["short", "fixed-prompt", "sliding-window"],
            label="Mode",
            value="short",
        )

    # Seed widget plus the click-callback kwargs that randomize it per run.
    seed, randomize_seed_callback = randomize_seed_ui()
    unload_model_button("valle_x")

    audio_out = gr.Audio(label="Generated audio")
    generate_button = gr.Button("Generate")

    # Wired here (not inside the accordion) so it can reference `language`.
    split_text_into_tokens_button.click(
        fn=preprocess_text,
        inputs=[text, language],
        outputs=[tokens],
        api_name="vall_e_x_tokenize",
    )

    # Component -> kwarg-name mapping consumed by dictionarize().
    input_dict = {
        text: "text",
        prompt: "prompt",
        language: "language",
        accent: "accent",
        mode: "mode",
        seed: "seed",
    }
    # Result-key -> component mapping; metadata/folder_root are hidden outputs
    # populated by the save decorators.
    output_dict = {
        "audio_out": audio_out,
        "metadata": gr.JSON(visible=False),
        "folder_root": gr.Textbox(visible=False),
    }

    # First randomize the seed, then run generation with the dict adapter.
    generate_button.click(
        **randomize_seed_callback,
    ).then(
        **dictionarize(
            fn=generate_audio_gradio,
            inputs=input_dict,
            outputs=output_dict,
        ),
        api_name="vall_e_x_generate",
    )
def valle_x_ui_generation_prompt_making():
    """Stub for a future prompt-making UI; currently only imports helpers."""
    from valle_x.utils.prompt_making import transcribe_one, make_prompt, make_transcript

    # Helper signatures, kept for reference:
    # transcribe_one(model, audio_path)
    # make_prompt(name, audio_prompt_path, transcript=None)
    # make_transcript(name, wav, sr, transcript=None)
def _valle_x_ui_prompt_making():
    """Placeholder prompt-making UI: a single audio-input column."""
    with gr.Column():
        audio = gr.Audio(label="Audio")
  151. def valle_x_tab():
  152. with gr.Tab("Valle-X", id="valle_x"):
  153. valle_x_ui_generation()
  154. # with gr.Tab("Valle-X Prompt Making Demo", id="valle_x_prompt_making"):
  155. # valle_x_ui_prompt_making()
if __name__ == "__main__":
    # Standalone debug entry point: relaunch cleanly on hot reload.
    # only if demo has been defined
    if "demo" in locals():
        demo.close()
    with gr.Blocks() as demo:
        valle_x_tab()
    demo.launch(
        server_port=7770,
    )