util.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import sys
  2. import asyncio
  3. from io import BytesIO
  4. from fairseq import checkpoint_utils
  5. import torch
  6. import edge_tts
  7. import librosa
  8. # https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/main/config.py#L43-L55 # noqa
  9. def has_mps() -> bool:
  10. if sys.platform != "darwin":
  11. return False
  12. else:
  13. if not getattr(torch, 'has_mps', False):
  14. return False
  15. try:
  16. torch.zeros(1).to(torch.device("mps"))
  17. return True
  18. except Exception:
  19. return False
  20. def is_half(device: str) -> bool:
  21. if not device.startswith('cuda'):
  22. return False
  23. else:
  24. gpu_name = torch.cuda.get_device_name(
  25. int(device.split(':')[-1])
  26. ).upper()
  27. # ...regex?
  28. if (
  29. ('16' in gpu_name and 'V100' not in gpu_name)
  30. or 'P40' in gpu_name
  31. or '1060' in gpu_name
  32. or '1070' in gpu_name
  33. or '1080' in gpu_name
  34. ):
  35. return False
  36. return True
  37. def load_hubert_model(device: str, model_path: str = 'hubert_base.pt'):
  38. model = checkpoint_utils.load_model_ensemble_and_task(
  39. [model_path]
  40. )[0][0].to(device)
  41. if is_half(device):
  42. return model.half()
  43. else:
  44. return model.float()
  45. async def call_edge_tts(speaker_name: str, text: str):
  46. tts_com = edge_tts.Communicate(text, speaker_name)
  47. tts_raw = b''
  48. # Stream TTS audio to bytes
  49. async for chunk in tts_com.stream():
  50. if chunk['type'] == 'audio':
  51. tts_raw += chunk['data']
  52. # Convert mp3 stream to wav
  53. ffmpeg_proc = await asyncio.create_subprocess_exec(
  54. 'ffmpeg',
  55. '-f', 'mp3',
  56. '-i', '-',
  57. '-f', 'wav',
  58. '-',
  59. stdin=asyncio.subprocess.PIPE,
  60. stdout=asyncio.subprocess.PIPE
  61. )
  62. (tts_wav, _) = await ffmpeg_proc.communicate(tts_raw)
  63. return librosa.load(BytesIO(tts_wav))