funasr_asr.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # -*- coding:utf-8 -*-
  2. import argparse
  3. import os
  4. import traceback
  5. from tqdm import tqdm
  6. from funasr import AutoModel
  # Each model is loaded from a local copy under tools/asr/models when that
  # directory exists; otherwise the "iic/..." id is handed to AutoModel instead
  # (presumably fetched from a remote model hub -- confirm against funasr docs).
  def _local_or_hub(local_dir, hub_id):
      """Return local_dir if it exists on disk, otherwise hub_id."""
      if os.path.exists(local_dir):
          return local_dir
      return hub_id


  path_asr = _local_or_hub(
      'tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
      "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
  )
  path_vad = _local_or_hub(
      'tools/asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch',
      "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
  )
  path_punc = _local_or_hub(
      'tools/asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
      "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
  )

  # All three models (ASR, VAD, punctuation) are pinned to the same revision.
  _MODEL_REVISION = "v2.0.4"

  # Module-level pipeline object, constructed once at import time and used by
  # only_asr() and execute_asr().
  model = AutoModel(
      model=path_asr,
      model_revision=_MODEL_REVISION,
      vad_model=path_vad,
      vad_model_revision=_MODEL_REVISION,
      punc_model=path_punc,
      punc_model_revision=_MODEL_REVISION,
  )
  def only_asr(input_file):
      """Transcribe a single input with the module-level FunASR ``model``.

      Args:
          input_file: The input passed to ``model.generate`` (an audio file
              path in this script).

      Returns:
          The ``"text"`` field of the first result, or ``''`` if transcription
          raises an ordinary exception (the traceback is printed).
      """
      try:
          return model.generate(input=input_file)[0]["text"]
      except Exception:
          # Best-effort: report the failure and fall back to an empty
          # transcript. Catching Exception (not a bare except) lets
          # KeyboardInterrupt / SystemExit still stop the program.
          print(traceback.format_exc())
          return ''
  def execute_asr(input_folder, output_folder, model_size, language):
      """Transcribe every entry of ``input_folder`` and write a ``.list`` file.

      Entries are processed in sorted name order. Each successfully
      transcribed entry becomes one line of the form::

          {input_folder}/{name}|{folder_name}|{LANGUAGE}|{text}

      where ``folder_name`` is the basename of ``input_folder``. Presumably
      downstream tools treat it as a speaker name; confirm this. Entries
      whose transcription raises are reported (traceback printed) and
      omitted from the output.

      Args:
          input_folder: Folder whose entries are fed to ``model.generate``.
          output_folder: Destination folder. Falls back to
              ``"output/asr_opt"`` when falsy and is created if missing.
          model_size: Not used by this implementation. Presumably it is kept
              so the signature matches other ASR back-ends; confirm with the
              callers.
          language: Language tag, written upper-cased into every line.

      Returns:
          Absolute path of the written ``{folder_name}.list`` file.
      """
      folder_name = os.path.basename(input_folder)
      lang_tag = language.upper()  # loop-invariant, compute once
      output = []
      for name in tqdm(sorted(os.listdir(input_folder))):
          audio_path = f"{input_folder}/{name}"
          try:
              text = model.generate(input=audio_path)[0]["text"]
          except Exception:
              # Skip this entry but keep going. Narrowed from a bare except so
              # Ctrl-C still aborts the batch.
              print(traceback.format_exc())
              continue
          output.append(f"{audio_path}|{folder_name}|{lang_tag}|{text}")

      output_folder = output_folder or "output/asr_opt"
      os.makedirs(output_folder, exist_ok=True)
      output_file_path = os.path.abspath(f'{output_folder}/{folder_name}.list')
      with open(output_file_path, "w", encoding="utf-8") as f:
          f.write("\n".join(output))
      # Message reads: "ASR task finished -> annotation file path: ...".
      print(f"ASR 任务完成->标注文件路径: {output_file_path}\n")
      return output_file_path
  if __name__ == '__main__':
      # Command-line entry point: parse options and run the batch transcription.
      arg_parser = argparse.ArgumentParser()
      arg_parser.add_argument(
          "-i", "--input_folder",
          type=str,
          required=True,
          help="Path to the folder containing WAV files.",
      )
      arg_parser.add_argument(
          "-o", "--output_folder",
          type=str,
          required=True,
          help="Output folder to store transcriptions.",
      )
      arg_parser.add_argument(
          "-s", "--model_size",
          type=str,
          default='large',
          help="Model Size of FunASR is Large",
      )
      arg_parser.add_argument(
          "-l", "--language",
          type=str,
          default='zh',
          choices=['zh'],
          help="Language of the audio files.",
      )
      # NOTE: --precision is parsed but not wired up yet; it is never passed
      # to execute_asr below.
      arg_parser.add_argument(
          "-p", "--precision",
          type=str,
          default='float16',
          choices=['float16', 'float32'],
          help="fp16 or fp32",
      )
      args = arg_parser.parse_args()

      execute_asr(
          input_folder=args.input_folder,
          output_folder=args.output_folder,
          model_size=args.model_size,
          language=args.language,
      )