inference.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import torch
  2. from synthesizer import audio
  3. from synthesizer.hparams import hparams
  4. from synthesizer.models.tacotron import Tacotron
  5. from synthesizer.utils.symbols import symbols
  6. from synthesizer.utils.text import text_to_sequence
  7. from vocoder.display import simple_table
  8. from pathlib import Path
  9. from typing import Union, List
  10. import numpy as np
  11. import librosa
  12. class Synthesizer:
  13. sample_rate = hparams.sample_rate
  14. hparams = hparams
  15. def __init__(self, model_fpath: Path, verbose=True):
  16. """
  17. The model isn't instantiated and loaded in memory until needed or until load() is called.
  18. :param model_fpath: path to the trained model file
  19. :param verbose: if False, prints less information when using the model
  20. """
  21. self.model_fpath = model_fpath
  22. self.verbose = verbose
  23. # Check for GPU
  24. if torch.cuda.is_available():
  25. self.device = torch.device("cuda")
  26. else:
  27. self.device = torch.device("cpu")
  28. if self.verbose:
  29. print("Synthesizer using device:", self.device)
  30. # Tacotron model will be instantiated later on first use.
  31. self._model = None
  32. def is_loaded(self):
  33. """
  34. Whether the model is loaded in memory.
  35. """
  36. return self._model is not None
  37. def load(self):
  38. """
  39. Instantiates and loads the model given the weights file that was passed in the constructor.
  40. """
  41. self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
  42. num_chars=len(symbols),
  43. encoder_dims=hparams.tts_encoder_dims,
  44. decoder_dims=hparams.tts_decoder_dims,
  45. n_mels=hparams.num_mels,
  46. fft_bins=hparams.num_mels,
  47. postnet_dims=hparams.tts_postnet_dims,
  48. encoder_K=hparams.tts_encoder_K,
  49. lstm_dims=hparams.tts_lstm_dims,
  50. postnet_K=hparams.tts_postnet_K,
  51. num_highways=hparams.tts_num_highways,
  52. dropout=hparams.tts_dropout,
  53. stop_threshold=hparams.tts_stop_threshold,
  54. speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
  55. self._model.load(self.model_fpath)
  56. self._model.eval()
  57. if self.verbose:
  58. print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
  59. def synthesize_spectrograms(self, texts: List[str],
  60. embeddings: Union[np.ndarray, List[np.ndarray]],
  61. return_alignments=False):
  62. """
  63. Synthesizes mel spectrograms from texts and speaker embeddings.
  64. :param texts: a list of N text prompts to be synthesized
  65. :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
  66. :param return_alignments: if True, a matrix representing the alignments between the
  67. characters
  68. and each decoder output step will be returned for each spectrogram
  69. :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
  70. sequence length of spectrogram i, and possibly the alignments.
  71. """
  72. # Load the model on the first request.
  73. if not self.is_loaded():
  74. self.load()
  75. # Preprocess text inputs
  76. inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
  77. if not isinstance(embeddings, list):
  78. embeddings = [embeddings]
  79. # Batch inputs
  80. batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
  81. for i in range(0, len(inputs), hparams.synthesis_batch_size)]
  82. batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
  83. for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
  84. specs = []
  85. for i, batch in enumerate(batched_inputs, 1):
  86. if self.verbose:
  87. print(f"\n| Generating {i}/{len(batched_inputs)}")
  88. # Pad texts so they are all the same length
  89. text_lens = [len(text) for text in batch]
  90. max_text_len = max(text_lens)
  91. chars = [pad1d(text, max_text_len) for text in batch]
  92. chars = np.stack(chars)
  93. # Stack speaker embeddings into 2D array for batch processing
  94. speaker_embeds = np.stack(batched_embeds[i-1])
  95. # Convert to tensor
  96. chars = torch.tensor(chars).long().to(self.device)
  97. speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
  98. # Inference
  99. _, mels, alignments = self._model.generate(chars, speaker_embeddings)
  100. mels = mels.detach().cpu().numpy()
  101. for m in mels:
  102. # Trim silence from end of each spectrogram
  103. while np.max(m[:, -1]) < hparams.tts_stop_threshold:
  104. m = m[:, :-1]
  105. specs.append(m)
  106. if self.verbose:
  107. print("\n\nDone.\n")
  108. return (specs, alignments) if return_alignments else specs
  109. @staticmethod
  110. def load_preprocess_wav(fpath):
  111. """
  112. Loads and preprocesses an audio file under the same conditions the audio files were used to
  113. train the synthesizer.
  114. """
  115. wav = librosa.load(str(fpath), hparams.sample_rate)[0]
  116. if hparams.rescale:
  117. wav = wav / np.abs(wav).max() * hparams.rescaling_max
  118. return wav
  119. @staticmethod
  120. def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
  121. """
  122. Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
  123. were fed to the synthesizer when training.
  124. """
  125. if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
  126. wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
  127. else:
  128. wav = fpath_or_wav
  129. mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
  130. return mel_spectrogram
  131. @staticmethod
  132. def griffin_lim(mel):
  133. """
  134. Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
  135. with the same parameters present in hparams.py.
  136. """
  137. return audio.inv_mel_spectrogram(mel, hparams)
  138. def pad1d(x, max_len, pad_value=0):
  139. return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)