import os
from iso639 import Lang
import torch
import gradio as gr

from tts_webui.decorators.gradio_dict_decorator import dictionarize
from tts_webui.utils.manage_model_state import manage_model_state
from tts_webui.utils.list_dir_models import unload_model_button
from tts_webui.decorators.decorator_apply_torch_seed import decorator_apply_torch_seed
from tts_webui.decorators.decorator_log_generation import decorator_log_generation
from tts_webui.decorators.decorator_save_metadata import decorator_save_metadata
from tts_webui.decorators.decorator_save_wav import decorator_save_wav
from tts_webui.decorators.decorator_add_base_filename import decorator_add_base_filename
from tts_webui.decorators.decorator_add_date import decorator_add_date
from tts_webui.decorators.decorator_add_model_type import decorator_add_model_type
from tts_webui.decorators.log_function_time import log_function_time
from tts_webui.extensions_loader.decorator_extensions import (
    decorator_extension_outer,
    decorator_extension_inner,
)
from tts_webui.utils.randomize_seed import randomize_seed_ui

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from transformers import VitsTokenizer, VitsModel


@manage_model_state("mms")
def preload_models_if_needed(language="eng") -> tuple["VitsModel", "VitsTokenizer"]:
    # Imported lazily so transformers is only loaded once MMS is actually used;
    # manage_model_state keeps the loaded pair around between generations.
    from transformers import VitsTokenizer, VitsModel

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VitsModel.from_pretrained(  # type: ignore
        f"facebook/mms-tts-{language}",
    )
    model = model.to(device)  # type: ignore
    tokenizer = VitsTokenizer.from_pretrained(  # type: ignore
        f"facebook/mms-tts-{language}",
    )
    return model, tokenizer


@decorator_extension_outer
@decorator_apply_torch_seed
@decorator_save_metadata
@decorator_save_wav
@decorator_add_model_type("mms")
@decorator_add_base_filename
@decorator_add_date
@decorator_log_generation
@decorator_extension_inner
@log_function_time
def generate_audio_with_mms(
    text,
    language="eng",
    speaking_rate=1.0,
    noise_scale=0.667,
    noise_scale_duration=0.8,
    **kwargs,
):
    model, tokenizer = preload_models_if_needed(language)

    # VitsModel reads these synthesis controls from attributes on the model instance.
    model.speaking_rate = speaking_rate
    model.noise_scale = noise_scale
    model.noise_scale_duration = noise_scale_duration

    inputs = tokenizer(text=text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)  # type: ignore

    waveform = outputs.waveform[0].cpu().numpy().squeeze()
    return {
        "audio_out": (model.config.sampling_rate, waveform),
    }
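

# A minimal usage sketch outside the Gradio UI. It mirrors the kwargs the UI
# wires up in input_dict below; the assumption (not verified here) is that the
# decorator stack passes these through unchanged and that
# decorator_apply_torch_seed consumes "seed". The helper name is hypothetical
# and is not called anywhere in this module.
def _example_generate_audio_with_mms():
    result = generate_audio_with_mms(
        text="hello world",
        language="eng",
        speaking_rate=1.0,
        noise_scale=0.667,
        noise_scale_duration=0.8,
        seed=0,
    )
    sampling_rate, waveform = result["audio_out"]
    return sampling_rate, waveform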


def get_mms_languages():
    # Yields (display name, ISO 639-3 code) pairs for the language dropdown;
    # each line of the list file starts with a three-letter ISO 639-3 code.
    with open(os.path.join("./tts_webui/mms/", "mms-languages-iso639-3.txt")) as f:
        for line in f:
            yield (Lang(line[:3]).name + line[3:].strip(), line[:3])


def mms_ui():
    gr.Markdown(
        """
# MMS

To use it, simply enter your text and click "Generate"; the model will synthesise speech from the text.
It uses the [MMS](https://huggingface.co/facebook/mms-tts) model from HuggingFace.

The MMS-TTS checkpoints are trained on lower-cased, un-punctuated text. By default, the VitsTokenizer normalizes the inputs by stripping casing and punctuation, to avoid passing out-of-vocabulary characters to the model. The model is therefore agnostic to casing and punctuation, so both should be avoided in the text prompt.
For certain languages with non-Roman alphabets, such as Arabic, Mandarin or Hindi, the uroman perl package is required to pre-process the text inputs into the Roman alphabet.

- Speaking Rate: larger values give faster synthesised speech.
- Noise Scale: how random the speech prediction is; larger values create more variation in the predicted speech.
- Noise Scale Duration: how random the duration prediction is; larger values create more variation in the predicted durations.
"""
    )

    with gr.Row():
        with gr.Column():
            mms_input = gr.Textbox(lines=2, label="Input Text")
            mms_generate_button = gr.Button("Generate")
        with gr.Column():
            mms_language = gr.Dropdown(
                choices=list(get_mms_languages()),
                label="Language",
                value="eng",
            )
            speaking_rate = gr.Slider(
                minimum=0.1,
                maximum=10.0,
                step=0.1,
                label="Speaking Rate",
                value=1.0,
            )
            noise_scale = gr.Slider(
                minimum=-2.5,
                maximum=2.5,
                step=0.05,
                label="Noise Scale",
                value=0.667,
            )
            noise_scale_duration = gr.Slider(
                minimum=-1.0,
                maximum=2.0,
                step=0.05,
                label="Noise Scale Duration",
                value=0.8,
            )
            with gr.Row():
                seed, randomize_seed_callback = randomize_seed_ui()
                unload_model_button("mms")

    audio_out = gr.Audio(label="Output Audio")

    input_dict = {
        mms_input: "text",
        mms_language: "language",
        speaking_rate: "speaking_rate",
        noise_scale: "noise_scale",
        noise_scale_duration: "noise_scale_duration",
        seed: "seed",
    }

    output_dict = {
        "audio_out": audio_out,
        "metadata": gr.JSON(visible=False),
        "folder_root": gr.Textbox(visible=False),
    }

    mms_generate_button.click(
        **randomize_seed_callback,
    ).then(
        **dictionarize(
            fn=generate_audio_with_mms,
            inputs=input_dict,
            outputs=output_dict,
        ),
        api_name="mms",
    )
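

# Illustrative sketch only (not wired into the UI): the uroman pre-processing
# mentioned in the Markdown above, for languages written in non-Roman scripts.
# It assumes the uroman perl package (https://github.com/isi-nlp/uroman) is
# cloned locally and that `perl` is on PATH; `uroman_path` and the helper name
# are hypothetical.
def _uromanize_example(text: str, uroman_path: str) -> str:
    import subprocess

    script = os.path.join(uroman_path, "bin", "uroman.pl")
    result = subprocess.run(
        ["perl", script],
        input=text.encode("utf-8"),
        capture_output=True,
        check=True,
    )
    return result.stdout.decode("utf-8").strip()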


def mms_tab():
    with gr.Tab(label="MMS"):
        mms_ui()


if __name__ == "__main__":
    if "demo" in locals():
        locals()["demo"].close()
    with gr.Blocks() as demo:
        mms_tab()

    demo.launch(
        server_port=7770,
    )