mms_tab.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import os
  2. from iso639 import Lang
  3. import torch
  4. import gradio as gr
  5. from tts_webui.decorators.gradio_dict_decorator import dictionarize
  6. from tts_webui.utils.manage_model_state import manage_model_state
  7. from tts_webui.utils.list_dir_models import unload_model_button
  8. from tts_webui.decorators.decorator_apply_torch_seed import decorator_apply_torch_seed
  9. from tts_webui.decorators.decorator_log_generation import decorator_log_generation
  10. from tts_webui.decorators.decorator_save_metadata import decorator_save_metadata
  11. from tts_webui.decorators.decorator_save_wav import decorator_save_wav
  12. from tts_webui.decorators.decorator_add_base_filename import decorator_add_base_filename
  13. from tts_webui.decorators.decorator_add_date import decorator_add_date
  14. from tts_webui.decorators.decorator_add_model_type import decorator_add_model_type
  15. from tts_webui.decorators.log_function_time import log_function_time
  16. from tts_webui.extensions_loader.decorator_extensions import (
  17. decorator_extension_outer,
  18. decorator_extension_inner,
  19. )
  20. from tts_webui.utils.randomize_seed import randomize_seed_ui
  21. from typing import TYPE_CHECKING
  22. if TYPE_CHECKING:
  23. from transformers import VitsTokenizer, VitsModel
  24. @manage_model_state("mms")
  25. def preload_models_if_needed(language="eng") -> tuple["VitsModel", "VitsTokenizer"]:
  26. from transformers import VitsTokenizer, VitsModel
  27. device = "cuda" if torch.cuda.is_available() else "cpu"
  28. model = VitsModel.from_pretrained( # type: ignore
  29. f"facebook/mms-tts-{language}",
  30. )
  31. model = model.to(device) # type: ignore
  32. tokenizer = VitsTokenizer.from_pretrained( # type: ignore
  33. f"facebook/mms-tts-{language}",
  34. ) # type: ignore
  35. return model, tokenizer
  36. @decorator_extension_outer
  37. @decorator_apply_torch_seed
  38. @decorator_save_metadata
  39. @decorator_save_wav
  40. @decorator_add_model_type("mms")
  41. @decorator_add_base_filename
  42. @decorator_add_date
  43. @decorator_log_generation
  44. @decorator_extension_inner
  45. @log_function_time
  46. def generate_audio_with_mms(
  47. text,
  48. language="eng",
  49. speaking_rate=1.0,
  50. noise_scale=0.667,
  51. noise_scale_duration=0.8,
  52. **kwargs,
  53. ):
  54. model, tokenizer = preload_models_if_needed(language)
  55. model.speaking_rate = speaking_rate
  56. model.noise_scale = noise_scale
  57. model.noise_scale_duration = noise_scale_duration
  58. inputs = tokenizer(text=text, return_tensors="pt").to(model.device)
  59. with torch.no_grad():
  60. outputs = model(**inputs) # type: ignore
  61. waveform = outputs.waveform[0].cpu().numpy().squeeze()
  62. return {
  63. "audio_out": (model.config.sampling_rate, waveform),
  64. }
  65. def get_mms_languages():
  66. with open(os.path.join("./tts_webui/mms/", "mms-languages-iso639-3.txt")) as f:
  67. for line in f:
  68. yield (Lang(line[:3]).name + line[3:].strip(), line[:3])
def mms_ui():
    """Build the MMS text-to-speech controls and wire up the Generate button.

    Must be called inside an active Gradio Blocks/Tab context; components
    are created in order and attach themselves to the enclosing layout.

    NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    from a copy of this file whose indentation had been stripped. In
    particular, whether the seed/unload row sits inside the second column
    and whether ``audio_out`` is at the top level should be confirmed
    against the original file.
    """
    # NOTE(review): paragraph breaks inside this Markdown text may have been
    # lost together with the file's blank lines -- confirm against the original.
    gr.Markdown(
        """
    # MMS
    To use it, simply enter your text, and click "Generate".
    The model will generate speech from the text.
    It uses the [MMS](https://huggingface.co/facebook/mms-tts) model from HuggingFace.
    The MMS-TTS checkpoints are trained on lower-cased, un-punctuated text. By default, the VitsTokenizer normalizes the inputs by removing any casing and punctuation, to avoid passing out-of-vocabulary characters to the model. Hence, the model is agnostic to casing and punctuation, so these should be avoided in the text prompt.
    For certain languages with non-Roman alphabets, such as Arabic, Mandarin or Hindi, the uroman perl package is required to pre-process the text inputs to the Roman alphabet.
    Speaking rate. Larger values give faster synthesised speech.
    Noise scale. How random the speech prediction is. Larger values create more variation in the predicted speech.
    Noise scale duration. How random the duration prediction is. Larger values create more variation in the predicted durations.
    """
    )
    with gr.Row():
        # Left column: the prompt and the button that starts generation.
        with gr.Column():
            mms_input = gr.Textbox(lines=2, label="Input Text")
            mms_generate_button = gr.Button("Generate")
        # Right column: language choice and generation parameters.
        with gr.Column():
            mms_language = gr.Dropdown(
                # Each choice is a (label, code) tuple from get_mms_languages().
                choices=list(get_mms_languages()),
                label="Language",
                value="eng",
            )
            speaking_rate = gr.Slider(
                minimum=0.1,
                maximum=10.0,
                step=0.1,
                label="Speaking Rate",
                value=1.0,
            )
            noise_scale = gr.Slider(
                minimum=-2.5,
                maximum=2.5,
                step=0.05,
                label="Noise Scale",
                value=0.667,
            )
            noise_scale_duration = gr.Slider(
                minimum=-1.0,
                maximum=2,
                step=0.05,
                label="Noise Scale Duration",
                value=0.8,
            )
            with gr.Row():
                # randomize_seed_ui() returns the seed component plus a dict
                # of kwargs that is splatted into the first click handler below.
                seed, randomize_seed_callback = randomize_seed_ui()
                unload_model_button("mms")
    audio_out = gr.Audio(label="Output Audio")
    # Component -> keyword-name mapping handed to dictionarize(); the names
    # match the parameters of generate_audio_with_mms ("seed" is not a named
    # parameter there and lands in **kwargs).
    input_dict = {
        mms_input: "text",
        mms_language: "language",
        speaking_rate: "speaking_rate",
        noise_scale: "noise_scale",
        noise_scale_duration: "noise_scale_duration",
        seed: "seed",
    }
    # Result-key -> component mapping. Only "audio_out" is returned by
    # generate_audio_with_mms itself; "metadata" and "folder_root" are
    # presumably added by its decorators -- TODO confirm. Both are hidden.
    output_dict = {
        "audio_out": audio_out,
        "metadata": gr.JSON(visible=False),
        "folder_root": gr.Textbox(visible=False),
    }
    # Two-step event chain: first the seed callback, then the generation,
    # which is also exposed under the API name "mms".
    mms_generate_button.click(
        **randomize_seed_callback,
    ).then(
        **dictionarize(
            fn=generate_audio_with_mms,
            inputs=input_dict,
            outputs=output_dict,
        ),
        api_name="mms",
    )
  141. def mms_tab():
  142. with gr.Tab(label="MMS"):
  143. mms_ui()
  144. if __name__ == "__main__":
  145. if "demo" in locals():
  146. locals()["demo"].close()
  147. with gr.Blocks() as demo:
  148. mms_tab()
  149. demo.launch(
  150. server_port=7770,
  151. )