2
0

seamless_tab.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import gradio as gr
  2. import torch
  3. import torchaudio
  4. from tts_webui.seamlessM4T.language_code_to_name import (
  5. text_source_languages,
  6. speech_target_languages,
  7. text_source_codes,
  8. speech_target_codes,
  9. )
  10. from tts_webui.decorators.gradio_dict_decorator import gradio_dict_decorator
  11. from tts_webui.utils.randomize_seed import randomize_seed_ui
  12. from tts_webui.utils.manage_model_state import manage_model_state
  13. from tts_webui.utils.list_dir_models import unload_model_button
  14. from tts_webui.decorators.decorator_apply_torch_seed import decorator_apply_torch_seed
  15. from tts_webui.decorators.decorator_log_generation import decorator_log_generation
  16. from tts_webui.decorators.decorator_save_metadata import decorator_save_metadata
  17. from tts_webui.decorators.decorator_save_wav import decorator_save_wav
  18. from tts_webui.decorators.decorator_add_base_filename import decorator_add_base_filename
  19. from tts_webui.decorators.decorator_add_date import decorator_add_date
  20. from tts_webui.decorators.decorator_add_model_type import decorator_add_model_type
  21. from tts_webui.decorators.log_function_time import log_function_time
  22. from tts_webui.extensions_loader.decorator_extensions import (
  23. decorator_extension_outer,
  24. decorator_extension_inner,
  25. )
  26. @manage_model_state("seamless")
  27. def get_model(model_name=""):
  28. from transformers import AutoProcessor, SeamlessM4Tv2Model
  29. # todo - add device setting
  30. return SeamlessM4Tv2Model.from_pretrained(
  31. model_name
  32. ), AutoProcessor.from_pretrained(model_name)
  33. @decorator_extension_outer
  34. @decorator_apply_torch_seed
  35. @decorator_save_metadata
  36. @decorator_save_wav
  37. @decorator_add_model_type("seamless")
  38. @decorator_add_base_filename
  39. @decorator_add_date
  40. @decorator_log_generation
  41. @decorator_extension_inner
  42. @log_function_time
  43. def seamless_translate(text, src_lang_name, tgt_lang_name, **kwargs):
  44. model, processor = get_model("facebook/seamless-m4t-v2-large")
  45. src_lang = text_source_codes[text_source_languages.index(src_lang_name)]
  46. tgt_lang = speech_target_codes[speech_target_languages.index(tgt_lang_name)]
  47. text_inputs = processor(text=text, src_lang=src_lang, return_tensors="pt")
  48. audio_array_from_text = (
  49. model.generate(**text_inputs, tgt_lang=tgt_lang)[0].cpu().squeeze()
  50. )
  51. sample_rate = model.config.sampling_rate
  52. return {"audio_out": (sample_rate, audio_array_from_text.numpy())}
  53. @decorator_extension_outer
  54. @decorator_apply_torch_seed
  55. @decorator_save_metadata
  56. @decorator_save_wav
  57. @decorator_add_model_type("seamless")
  58. @decorator_add_base_filename
  59. @decorator_add_date
  60. @decorator_log_generation
  61. @decorator_extension_inner
  62. @log_function_time
  63. def seamless_translate_audio(audio, tgt_lang_name):
  64. model, processor = get_model("facebook/seamless-m4t-v2-large")
  65. # audio, orig_freq = torchaudio.load(audio)
  66. orig_freq, audio = audio
  67. sample_rate = model.config.sampling_rate
  68. audio = torchaudio.functional.resample(
  69. torch.from_numpy(audio).float(), orig_freq=orig_freq, new_freq=16_000
  70. ) # must be a 16 kHz waveform array
  71. tgt_lang = speech_target_codes[speech_target_languages.index(tgt_lang_name)]
  72. audio_inputs = processor(audios=audio, return_tensors="pt")
  73. audio_array_from_audio = (
  74. model.generate(**audio_inputs, tgt_lang=tgt_lang)[0].cpu().squeeze()
  75. )
  76. return {"audio_out": (sample_rate, audio_array_from_audio.numpy())}
  77. def seamless_ui():
  78. gr.Markdown(
  79. """
  80. # Seamless Demo
  81. To use it, simply enter your text, and click "Translate".
  82. The model will translate the text into the target language, and then synthesize the translated text into speech.
  83. It uses the [SeamlessM4Tv2Model](https://huggingface.co/facebook/seamless-m4t-v2-large) model from HuggingFace.
  84. """
  85. )
  86. with gr.Row(equal_height=False):
  87. with gr.Column():
  88. with gr.Tab(label="Text to Speech"):
  89. seamless_input = gr.Textbox(lines=2, label="Input Text")
  90. source_language = gr.Dropdown(
  91. choices=text_source_languages, # type: ignore
  92. label="Source Language",
  93. value="English",
  94. type="value",
  95. )
  96. target_language = gr.Dropdown(
  97. choices=speech_target_languages, # type: ignore
  98. label="Target Language",
  99. value="Mandarin Chinese",
  100. type="value",
  101. )
  102. button = gr.Button("Translate Text to Speech")
  103. with gr.Tab(label="Audio to Speech"):
  104. input_audio = gr.Audio(
  105. sources="upload",
  106. type="numpy",
  107. label="Input Audio",
  108. )
  109. target_language_audio = gr.Dropdown(
  110. choices=speech_target_languages, # type: ignore
  111. label="Target Language (Not all are supported)",
  112. value="Mandarin Chinese",
  113. type="value",
  114. )
  115. button2 = gr.Button("Translate Audio to Speech")
  116. with gr.Column():
  117. audio_out = gr.Audio(label="Output Audio")
  118. seed, randomize_seed_callback = randomize_seed_ui()
  119. unload_model_button("seamless")
  120. input_dict = {
  121. seamless_input: "text",
  122. source_language: "src_lang_name",
  123. target_language: "tgt_lang_name",
  124. seed: "seed",
  125. }
  126. input_dict2 = {
  127. input_audio: "audio",
  128. target_language_audio: "tgt_lang_name",
  129. }
  130. output_dict = {
  131. "audio_out": audio_out,
  132. }
  133. button.click(
  134. **randomize_seed_callback,
  135. ).then(
  136. fn=gradio_dict_decorator(
  137. fn=seamless_translate,
  138. gradio_fn_input_dictionary=input_dict,
  139. outputs=output_dict,
  140. ),
  141. inputs={*input_dict},
  142. outputs=list(output_dict.values()),
  143. api_name="seamless",
  144. )
  145. button2.click(
  146. **randomize_seed_callback,
  147. ).then(
  148. fn=gradio_dict_decorator(
  149. fn=seamless_translate_audio,
  150. gradio_fn_input_dictionary=input_dict2,
  151. outputs=output_dict,
  152. ),
  153. inputs={*input_dict2},
  154. outputs=list(output_dict.values()),
  155. api_name="seamless_audio",
  156. )
  157. def seamless_tab():
  158. with gr.Tab("Seamless M4Tv2", id="seamless"):
  159. seamless_ui()
  160. if __name__ == "__main__":
  161. if "demo" in locals():
  162. demo.close() # type: ignore
  163. with gr.Blocks() as demo:
  164. seamless_tab()
  165. demo.launch()