main.py 4.4 KB

import gradio as gr
import torch
import os
from typing import TYPE_CHECKING

from tts_webui.utils.manage_model_state import manage_model_state
from tts_webui.utils.list_dir_models import unload_model_button

if TYPE_CHECKING:
    from transformers import Pipeline

def extension__tts_generation_webui():
    transcribe_ui()
    return {
        "package_name": "extension_whisper",
        "name": "Whisper",
        "version": "0.0.2",
        "requirements": "git+https://github.com/rsxdalv/extension_whisper@main",
        "description": "Whisper allows transcribing audio files.",
        "extension_type": "interface",
        "extension_class": "tools",
        "author": "rsxdalv",
        "extension_author": "rsxdalv",
        "license": "MIT",
        "website": "https://github.com/rsxdalv/extension_whisper",
        "extension_website": "https://github.com/rsxdalv/extension_whisper",
        "extension_platform_version": "0.0.1",
    }

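# get_model loads a Whisper checkpoint and its processor onto the requested
# device. The @manage_model_state("whisper") decorator comes from tts_webui and
# appears to cache the loaded model under that key, so repeated calls reuse it
# and the "whisper" unload button below can free it; the exact caching
# semantics are defined by tts_webui, not here. With compile=True, the forward
# pass is wrapped in torch.compile with a static KV cache for faster repeated
# generation.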
@manage_model_state("whisper")
def get_model(
    model_name="openai/whisper-large-v3",
    torch_dtype=torch.float16,
    device="cuda:0",
    compile=False,
):
    from transformers import AutoModelForSpeechSeq2Seq
    from transformers import AutoProcessor

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_name, torch_dtype=torch_dtype, low_cpu_mem_usage=True
    ).to(device)
    if compile:
        model.generation_config.cache_implementation = "static"
        model.generation_config.max_new_tokens = 256
        model.forward = torch.compile(
            model.forward, mode="reduce-overhead", fullgraph=True
        )
    processor = AutoProcessor.from_pretrained(model_name)
    return model, processor

local_dir = os.path.join("data", "models", "whisper")
local_cache_dir = os.path.join(local_dir, "cache")

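# get_pipe wraps the cached model and processor in a Transformers
# automatic-speech-recognition pipeline. It is registered under a separate
# "whisper-pipe" key, which is why the UI exposes two unload buttons: one for
# the pipeline and one for the underlying model.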
@manage_model_state("whisper-pipe")
def get_pipe(model_name, device="cuda:0") -> "Pipeline":
    from transformers import pipeline

    torch_dtype = torch.float16
    model, processor = get_model(
        # model_name, torch_dtype=torch.float16, device=device, compile=False
        model_name,
        torch_dtype=torch_dtype,
        device=device,
        compile=False,
    )
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        # chunk_length_s=30,
        # batch_size=16,  # batch size for inference - set based on your device
        torch_dtype=torch.float16,
        model_kwargs={"cache_dir": local_cache_dir},
        device=device,
    )

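# transcribe is the Gradio callback: it validates the input, fetches the
# cached pipeline, and returns only the transcription text.
# return_timestamps=True is passed, which the Transformers Whisper pipeline
# requires for long-form (>30 s) inputs. The "task" generate kwarg is only
# passed for the multilingual large-v3 checkpoint, since the *.en models are
# English-only.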
def transcribe(inputs, model_name="openai/whisper-large-v3"):
    if inputs is None:
        raise gr.Error(
            "No audio file submitted! Please upload or record an audio file before submitting your request."
        )
    pipe = get_pipe(model_name)

    result = pipe(
        inputs,
        generate_kwargs=(
            {"task": "transcribe"} if model_name == "openai/whisper-large-v3" else {}
        ),
        return_timestamps=True,
    )
    return result["text"]

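# transcribe_ui builds the Gradio layout: an audio upload and a model dropdown
# on the left, a read-only transcription textbox on the right, unload buttons
# for both cached model states, and a Transcribe button exposed under the
# "whisper_transcribe" API name.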
def transcribe_ui():
    gr.Markdown(
        "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the"
        " checkpoint [openai/whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) and 🤗 Transformers to transcribe audio files"
        " of arbitrary length."
    )
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Audio", type="filepath", sources="upload")
            model_dropdown = gr.Dropdown(
                choices=[
                    "openai/whisper-tiny.en",
                    "openai/whisper-small.en",
                    "openai/whisper-medium.en",
                    "openai/whisper-large-v3",
                ],
                label="Model",
                value="openai/whisper-large-v3",
            )
        with gr.Column():
            text = gr.Textbox(label="Transcription", interactive=False)

    with gr.Row():
        unload_model_button("whisper-pipe")
        unload_model_button("whisper")

        transcribe_button = gr.Button("Transcribe", variant="primary")

    transcribe_button.click(
        fn=transcribe,
        inputs=[audio, model_dropdown],
        outputs=[text],
        api_name="whisper_transcribe",
    )

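# When run as a standalone script (outside the TTS WebUI host), build a
# minimal Blocks app with a single Whisper tab. The locals() check closes a
# previously created demo if the file is re-executed in the same interpreter
# (e.g. during hot reload).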
if __name__ == "__main__":
    if "demo" in locals():
        locals()["demo"].close()
    with gr.Blocks() as demo:
        with gr.Tab("Whisper"):
            transcribe_ui()

    demo.queue().launch()
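
# Example programmatic use (a sketch; "sample.wav" is a placeholder path):
#
#     text = transcribe("sample.wav", model_name="openai/whisper-tiny.en")
#     print(text)
#
# With the demo above running, the same endpoint can be called remotely via
# gradio_client; the exact client call may vary with the Gradio version:
#
#     from gradio_client import Client, handle_file
#     client = Client("http://127.0.0.1:7860/")  # default local Gradio URL
#     print(
#         client.predict(
#             handle_file("sample.wav"),
#             "openai/whisper-large-v3",
#             api_name="/whisper_transcribe",
#         )
#     )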