"""Gradio tab for MahaTTS text-to-speech generation (maha_tts_tab.py)."""
import glob
import os
from importlib.metadata import version

import gradio as gr
import torch

from tts_webui.tortoise.gr_reload_button import gr_open_button_simple, gr_reload_button
from tts_webui.utils.list_dir_models import unload_model_button
from tts_webui.utils.randomize_seed import randomize_seed_ui
from tts_webui.utils.manage_model_state import manage_model_state
from tts_webui.decorators.gradio_dict_decorator import dictionarize
from tts_webui.decorators.decorator_apply_torch_seed import decorator_apply_torch_seed
from tts_webui.decorators.decorator_log_generation import decorator_log_generation
from tts_webui.decorators.decorator_save_metadata import decorator_save_metadata
from tts_webui.decorators.decorator_save_wav import decorator_save_wav
from tts_webui.decorators.decorator_add_base_filename import decorator_add_base_filename
from tts_webui.decorators.decorator_add_date import decorator_add_date
from tts_webui.decorators.decorator_add_model_type import decorator_add_model_type
from tts_webui.decorators.log_function_time import log_function_time
from tts_webui.extensions_loader.decorator_extensions import (
    decorator_extension_outer,
    decorator_extension_inner,
)
# Installed maha_tts package version; displayed in the UI below.
MAHA_VERSION = version("maha_tts")
  24. def get_ref_clips(speaker_name):
  25. return glob.glob(os.path.join("./voices-tortoise/", speaker_name, "*.wav"))
  26. def get_voice_list():
  27. files = os.listdir("./voices-tortoise/")
  28. dirs = [f for f in files if os.path.isdir(os.path.join("./voices-tortoise/", f))]
  29. return dirs
  30. @manage_model_state("maha_tts")
  31. def preload_models_if_needed(model_name, device):
  32. from maha_tts.inference import load_models
  33. return load_models(name=model_name, device=device)
  34. @decorator_extension_outer
  35. @decorator_apply_torch_seed
  36. @decorator_save_metadata
  37. @decorator_save_wav
  38. @decorator_add_model_type("maha_tts")
  39. @decorator_add_base_filename
  40. @decorator_add_date
  41. @decorator_log_generation
  42. @decorator_extension_inner
  43. @log_function_time
  44. def generate_audio_maha_tts(
  45. text,
  46. model_name,
  47. text_language,
  48. speaker_name,
  49. device="auto",
  50. **kwargs,
  51. ):
  52. from maha_tts.inference import infer_tts, config
  53. device = torch.device(
  54. device == "auto" and "cuda" if torch.cuda.is_available() else "cpu" or device
  55. )
  56. diff_model, ts_model, vocoder, diffuser = preload_models_if_needed(
  57. model_name=model_name, device=device
  58. )
  59. ref_clips = get_ref_clips(speaker_name)
  60. text_language = (
  61. torch.tensor(config.lang_index[text_language]).to(device).unsqueeze(0)
  62. )
  63. audio, sr = infer_tts(
  64. text, ref_clips, diffuser, diff_model, ts_model, vocoder, text_language
  65. )
  66. return {"audio_out": (sr, audio)}
  67. def maha_tts_ui():
  68. # from maha_tts.config import config
  69. class config:
  70. langs = [
  71. "english",
  72. "tamil",
  73. "telugu",
  74. "punjabi",
  75. "marathi",
  76. "hindi",
  77. "gujarati",
  78. "bengali",
  79. "assamese",
  80. ]
  81. gr.Markdown(
  82. """
  83. # Maha TTS Demo
  84. To use it, simply enter your text, and click "Generate".
  85. The model will generate speech from the text.
  86. It uses the [MahaTTS](https://huggingface.co/Dubverse/MahaTTS) model from HuggingFace.
  87. To make a voice, create a folder with the name of the voice in the `voices-tortoise` folder.
  88. Then, add the voice's wav files to the folder.
  89. A voice must be used. Some voices give errors.
  90. The reference voices can be downloaded [here](https://huggingface.co/Dubverse/MahaTTS/resolve/main/maha_tts/pretrained_models/infer_ref_wavs.zip).
  91. """
  92. )
  93. gr.Markdown(f"MahaTTS version: {MAHA_VERSION}")
  94. text = gr.Textbox(lines=2, label="Input Text")
  95. with gr.Row():
  96. model_name = gr.Radio(
  97. choices=[
  98. ("English", "Smolie-en"),
  99. ("Indian", "Smolie-in"),
  100. ],
  101. label="Model Language",
  102. value="Smolie-in",
  103. type="value",
  104. )
  105. device = gr.Radio(
  106. choices=["auto", "cuda", "cpu"],
  107. label="Device",
  108. value="auto",
  109. type="value",
  110. )
  111. text_language = gr.Radio(
  112. choices=list(config.langs),
  113. label="Text Language",
  114. value="english",
  115. type="value",
  116. )
  117. model_name.change(
  118. fn=lambda choice: choice == "Smolie-en"
  119. and gr.Radio(
  120. value="english",
  121. visible=False,
  122. interactive=False,
  123. )
  124. or gr.Radio(
  125. interactive=True,
  126. visible=True,
  127. ),
  128. inputs=[model_name],
  129. outputs=[text_language],
  130. )
  131. with gr.Column():
  132. gr.Markdown("Speaker Name")
  133. with gr.Row():
  134. voices = get_voice_list()
  135. speaker_name = gr.Dropdown(
  136. choices=voices, # type: ignore
  137. value=voices[0] if voices else "None",
  138. type="value",
  139. show_label=False,
  140. container=False,
  141. )
  142. gr_open_button_simple("voices-tortoise", api_name="maha_tts_open_voices")
  143. gr_reload_button().click(
  144. fn=lambda: gr.Dropdown(choices=get_voice_list()), # type: ignore
  145. outputs=[speaker_name],
  146. api_name="maha_tts_refresh_voices",
  147. )
  148. gr.Markdown("Note: The speaker audio must be mono at this time.")
  149. seed, randomize_seed_callback = randomize_seed_ui()
  150. unload_model_button("maha_tts")
  151. audio_out = gr.Audio(label="Output Audio")
  152. button = gr.Button("Generate")
  153. input_dict = {
  154. text: "text",
  155. model_name: "model_name",
  156. text_language: "text_language",
  157. speaker_name: "speaker_name",
  158. seed: "seed",
  159. device: "device",
  160. }
  161. output_dict = {
  162. "audio_out": audio_out,
  163. "metadata": gr.JSON(label="Metadata", visible=False),
  164. "folder_root": gr.Textbox(label="Folder root", visible=False),
  165. }
  166. button.click(
  167. **randomize_seed_callback,
  168. ).then(
  169. **dictionarize(
  170. fn=generate_audio_maha_tts,
  171. inputs=input_dict,
  172. outputs=output_dict,
  173. ),
  174. api_name="maha_tts",
  175. )
  176. def maha_tts_tab():
  177. with gr.Tab(label="Maha TTS"):
  178. maha_tts_ui()
  179. if __name__ == "__main__":
  180. if "demo" in locals():
  181. demo.close() # type: ignore
  182. with gr.Blocks() as demo:
  183. maha_tts_tab()
  184. demo.launch(
  185. server_port=7770,
  186. )