fasterwhisper_asr.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import argparse
  2. import os
  3. os.environ["HF_ENDPOINT"]="https://hf-mirror.com"
  4. import traceback
  5. import requests
  6. from glob import glob
  7. from faster_whisper import WhisperModel
  8. from tqdm import tqdm
  9. from tools.asr.config import check_fw_local_models
  10. from tools.asr.funasr_asr import only_asr
  11. os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
  12. language_code_list = [
  13. "af", "am", "ar", "as", "az",
  14. "ba", "be", "bg", "bn", "bo",
  15. "br", "bs", "ca", "cs", "cy",
  16. "da", "de", "el", "en", "es",
  17. "et", "eu", "fa", "fi", "fo",
  18. "fr", "gl", "gu", "ha", "haw",
  19. "he", "hi", "hr", "ht", "hu",
  20. "hy", "id", "is", "it", "ja",
  21. "jw", "ka", "kk", "km", "kn",
  22. "ko", "la", "lb", "ln", "lo",
  23. "lt", "lv", "mg", "mi", "mk",
  24. "ml", "mn", "mr", "ms", "mt",
  25. "my", "ne", "nl", "nn", "no",
  26. "oc", "pa", "pl", "ps", "pt",
  27. "ro", "ru", "sa", "sd", "si",
  28. "sk", "sl", "sn", "so", "sq",
  29. "sr", "su", "sv", "sw", "ta",
  30. "te", "tg", "th", "tk", "tl",
  31. "tr", "tt", "uk", "ur", "uz",
  32. "vi", "yi", "yo", "zh", "yue",
  33. "auto"]
  34. def execute_asr(input_folder, output_folder, model_size, language,precision):
  35. if '-local' in model_size:
  36. model_size = model_size[:-6]
  37. model_path = f'tools/asr/models/faster-whisper-{model_size}'
  38. else:
  39. model_path = model_size
  40. if language == 'auto':
  41. language = None #不设置语种由模型自动输出概率最高的语种
  42. print("loading faster whisper model:",model_size,model_path)
  43. try:
  44. model = WhisperModel(model_path, device="cuda", compute_type=precision)
  45. except:
  46. return print(traceback.format_exc())
  47. output = []
  48. output_file_name = os.path.basename(input_folder)
  49. output_file_path = os.path.abspath(f'{output_folder}/{output_file_name}.list')
  50. if not os.path.exists(output_folder):
  51. os.makedirs(output_folder)
  52. for file in tqdm(glob(os.path.join(input_folder, '**/*.wav'), recursive=True)):
  53. try:
  54. segments, info = model.transcribe(
  55. audio = file,
  56. beam_size = 5,
  57. vad_filter = True,
  58. vad_parameters = dict(min_silence_duration_ms=700),
  59. language = language)
  60. text = ''
  61. if info.language == "zh":
  62. print("检测为中文文本,转funasr处理")
  63. text = only_asr(file)
  64. if text == '':
  65. for segment in segments:
  66. text += segment.text
  67. output.append(f"{file}|{output_file_name}|{info.language.upper()}|{text}")
  68. except:
  69. return print(traceback.format_exc())
  70. with open(output_file_path, "w", encoding="utf-8") as f:
  71. f.write("\n".join(output))
  72. print(f"ASR 任务完成->标注文件路径: {output_file_path}\n")
  73. return output_file_path
  74. if __name__ == '__main__':
  75. parser = argparse.ArgumentParser()
  76. parser.add_argument("-i", "--input_folder", type=str, required=True,
  77. help="Path to the folder containing WAV files.")
  78. parser.add_argument("-o", "--output_folder", type=str, required=True,
  79. help="Output folder to store transcriptions.")
  80. parser.add_argument("-s", "--model_size", type=str, default='large-v3',
  81. choices=check_fw_local_models(),
  82. help="Model Size of Faster Whisper")
  83. parser.add_argument("-l", "--language", type=str, default='ja',
  84. choices=language_code_list,
  85. help="Language of the audio files.")
  86. parser.add_argument("-p", "--precision", type=str, default='float16', choices=['float16','float32'],
  87. help="fp16 or fp32")
  88. cmd = parser.parse_args()
  89. output_file_path = execute_asr(
  90. input_folder = cmd.input_folder,
  91. output_folder = cmd.output_folder,
  92. model_size = cmd.model_size,
  93. language = cmd.language,
  94. precision = cmd.precision,
  95. )