# tab_voice_clone.py

import torchaudio
import torch
import gradio as gr
import numpy as np
from encodec.utils import convert_audio
from bark.generation import load_codec_model
from encodec.model import EncodecModel
from tts_webui.bark.history_to_hash import history_to_hash
from tts_webui.bark.npz_tools import save_npz
from tts_webui.bark.FullGeneration import FullGeneration
from tts_webui.utils.date import get_date_string
from tts_webui.bark.get_audio_from_npz import get_audio_from_full_generation
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from bark_hubert_quantizer.pre_kmeans_hubert import CustomHubert
    from bark_hubert_quantizer.customtokenizer import CustomTokenizer


hubert_model = None


def _load_hubert_model(device):
    from bark_hubert_quantizer.hubert_manager import HuBERTManager
    from bark_hubert_quantizer.pre_kmeans_hubert import CustomHubert

    # Ensure the HuBERT checkpoint is installed, then lazily cache the model.
    hubert_path = HuBERTManager.make_sure_hubert_installed()
    global hubert_model
    if hubert_model is None:
        hubert_model = CustomHubert(
            checkpoint_path=hubert_path,
            device=device,
        )
    return hubert_model


def _get_semantic_vectors(hubert_model: "CustomHubert", path_to_wav: str, device):
    # Load the wav (soundfile would work here too); torchaudio returns a
    # (channels, samples) tensor plus the sample rate.
    wav, sr = torchaudio.load(path_to_wav)
    if wav.shape[0] == 2:  # Downmix stereo to mono if needed
        wav = wav.mean(0, keepdim=True)
    wav = wav.to(device)
    return hubert_model.forward(wav, input_sample_hz=sr)


def get_semantic_vectors(path_to_wav: str, device):
    hubert_model = _load_hubert_model(device)
    return _get_semantic_vectors(hubert_model, path_to_wav, device)
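
# Hypothetical shape check (not called anywhere in this module); the path is
# made up and the exact feature dimension depends on the HuBERT checkpoint:
#
#   vectors = get_semantic_vectors("samples/speaker.wav", device="cpu")
#   print(vectors.shape)  # roughly (n_frames, feature_dim)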


tokenizer = None


def _load_tokenizer(
    model: str = "quantifier_hubert_base_ls960_14.pth",
    repo: str = "GitMylo/bark-voice-cloning",
    force_reload: bool = False,
    device="cpu",
) -> "CustomTokenizer":
    from bark_hubert_quantizer.customtokenizer import CustomTokenizer
    from bark_hubert_quantizer.hubert_manager import HuBERTManager

    tokenizer_path = HuBERTManager.make_sure_tokenizer_installed(
        model=model,
        repo=repo,
        local_file=model,
    )
    global tokenizer
    if tokenizer is None or force_reload:
        tokenizer = CustomTokenizer.load_from_checkpoint(
            # "data/models/hubert/tokenizer.pth"
            tokenizer_path,
            map_location=device,
        )
        tokenizer.load_state_dict(torch.load(tokenizer_path, map_location=device))
    return tokenizer


def get_semantic_tokens(semantic_vectors: torch.Tensor, device):
    tokenizer = _load_tokenizer(device=device)
    return tokenizer.get_token(semantic_vectors)


def get_semantic_prompt(path_to_wav: str, device):
    semantic_vectors = get_semantic_vectors(path_to_wav, device)
    return get_semantic_tokens(semantic_vectors, device).cpu().numpy()


def get_prompts(path_to_wav: str, use_gpu: bool):
    device = "cuda" if use_gpu else "cpu"
    semantic_prompt = get_semantic_prompt(path_to_wav, device)
    fine_prompt, coarse_prompt = get_encodec_prompts(path_to_wav, use_gpu)
    return FullGeneration(
        semantic_prompt=semantic_prompt,
        coarse_prompt=coarse_prompt,
        fine_prompt=fine_prompt,
    )
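
# A minimal programmatic sketch of the cloning pipeline (not part of the UI
# flow); the path is hypothetical, and key-style access assumes FullGeneration
# behaves like a dict of numpy arrays:
#
#   full_generation = get_prompts("samples/speaker.wav", use_gpu=True)
#   print(full_generation["semantic_prompt"].shape)  # 1-D semantic token ids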


def get_encodec_prompts(path_to_wav: str, use_gpu=True):
    device = "cuda" if use_gpu else "cpu"
    model: EncodecModel = load_codec_model(use_gpu=use_gpu)
    wav, sr = torchaudio.load(path_to_wav)
    wav = convert_audio(wav, sr, model.sample_rate, model.channels)
    wav = wav.unsqueeze(0).to(device)
    model.to(device)

    # Extract discrete codes from EnCodec
    with torch.no_grad():
        encoded_frames = model.encode(wav)
    fine_prompt: np.ndarray = (
        torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
        .squeeze()
        .cpu()
        .numpy()
    )
    coarse_prompt = fine_prompt[:2, :]
    return fine_prompt, coarse_prompt
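
# Bark's codec is the 24 kHz EnCodec model, whose encoder emits 8 codebooks at
# 75 frames per second, so fine_prompt is (8, n_frames) and coarse_prompt keeps
# only the first two codebooks. A hypothetical round-trip sanity check using
# EnCodec's decode API (codes paired with no scale factor):
#
#   codes = torch.from_numpy(fine_prompt).unsqueeze(0).to(device)
#   with torch.no_grad():
#       audio = model.decode([(codes, None)])
#   torchaudio.save("roundtrip.wav", audio.squeeze(0).cpu(), model.sample_rate)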


def save_cloned_voice(full_generation: FullGeneration):
    voice_name = f"voice_from_audio_{history_to_hash(full_generation)}"
    filename = f"voices/{voice_name}.npz"
    date = get_date_string()
    metadata = generate_cloned_voice_metadata(full_generation, date)
    save_npz(filename, full_generation, metadata)
    return filename


def generate_cloned_voice_metadata(full_generation, date):
    return {
        "_version": "0.0.1",
        "_hash_version": "0.0.2",
        "_type": "bark",
        "hash": history_to_hash(full_generation),
        "date": date,
    }
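
# Hypothetical example of inspecting a saved voice; this assumes save_npz
# stores the three prompt arrays under their FullGeneration key names:
#
#   import numpy as np
#   voice = np.load("voices/voice_from_audio_<hash>.npz", allow_pickle=True)
#   print(voice.files)  # e.g. semantic_prompt, coarse_prompt, fine_prompt, ...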


def tab_voice_clone():
    with gr.Tab("Bark Voice Clone"), gr.Row(equal_height=False):
        with gr.Column():
            gr.Markdown(
                """
Unethical use of this technology is prohibited.

This demo is based on the https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer repository.

The advice below comes from the original repository (https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer?tab=readme-ov-file#voices-cloned-arent-very-convincing-why-are-other-peoples-cloned-voices-better-than-mine).

## Voices cloned aren't very convincing, why are other people's cloned voices better than mine?

Make sure these things are **NOT** in your voice input (in no particular order):

* Noise (you can run a noise remover first)
* Music (there are also music-removal tools, unless you want music in the background)
* A cut-off at the end (this will cause the model to try to continue the audio)
* Under 1 second of training data (I personally suggest around 10 seconds for good potential, but I've had great results with 5 seconds as well)

What makes for good prompt audio (in no particular order):

* Clearly spoken
* No weird background noises
* Only one speaker
* Audio which ends after a sentence ends
* A regular/common voice (these usually have more success; the model can still clone complex voices, but it is not as good at it)
* Around 10 seconds of data
"""
            )
        with gr.Column():
            tokenizer_dropdown = gr.Dropdown(
                label="Tokenizer",
                choices=[
                    "quantifier_hubert_base_ls960.pth @ GitMylo/bark-voice-cloning",
                    "quantifier_hubert_base_ls960_14.pth @ GitMylo/bark-voice-cloning",
                    "quantifier_V1_hubert_base_ls960_23.pth @ GitMylo/bark-voice-cloning",
                    "polish-HuBERT-quantizer_8_epoch.pth @ Hobis/bark-voice-cloning-polish-HuBERT-quantizer",
                    "german-HuBERT-quantizer_14_epoch.pth @ CountFloyd/bark-voice-cloning-german-HuBERT-quantizer",
                    "es_tokenizer.pth @ Lancer1408/bark-es-tokenizer",
                    "portuguese-HuBERT-quantizer_24_epoch.pth @ MadVoyager/bark-voice-cloning-portuguese-HuBERT-quantizer",
                    "turkish_model_epoch_14.pth @ egeadam/bark-voice-cloning-turkish-HuBERT-quantizer",
                    "japanese-HuBERT-quantizer_24_epoch.pth @ junwchina/bark-voice-cloning-japanese-HuBERT-quantizer",
                    "it_tokenizer.pth @ gpwr/bark-it-tokenizer",
                ],
                value="quantifier_hubert_base_ls960_14.pth @ GitMylo/bark-voice-cloning",
                allow_custom_value=True,
                interactive=True,
            )
            file_input = gr.Audio(
                label="Input Audio",
                type="filepath",
                sources="upload",
                interactive=True,
            )
            with gr.Row():
                use_gpu_checkbox = gr.Checkbox(label="Use GPU", value=True)
                clear_models_button = gr.Button(
                    "Clear models",
                    variant="secondary",
                )

                def clear_models():
                    global hubert_model
                    global tokenizer
                    hubert_model = None
                    tokenizer = None
                    torch.cuda.empty_cache()
                    return gr.Button(
                        value="Models cleared",
                    )

                clear_models_button.click(
                    fn=clear_models,
                    outputs=[clear_models_button],
                )

            generate_voice_button = gr.Button(value="Generate Voice", variant="primary")

            def load_tokenizer(tokenizer_and_repo: str, use_gpu: bool):
                tokenizer, repo = tokenizer_and_repo.split(" @ ")
                device = "cuda" if use_gpu else "cpu"
                _load_tokenizer(
                    model=tokenizer,
                    repo=repo,
                    force_reload=True,
                    device=device,
                )
                return tokenizer_and_repo

            tokenizer_dropdown.change(
                load_tokenizer,
                inputs=[tokenizer_dropdown, use_gpu_checkbox],
                outputs=[tokenizer_dropdown],
                api_name="bark_voice_tokenizer_load",
            )

            gr.Markdown("Generated voice:")
            voice_file_name = gr.Textbox(
                label="Voice file name", value="", interactive=False
            )
            audio_preview = gr.Audio(label="Encodec audio preview")
            gr.Markdown("The 'Use as history' button is now only available in the React UI.")

            def generate_voice(wav_file: str, use_gpu: bool):
                full_generation = get_prompts(wav_file, use_gpu)
                filename = save_cloned_voice(full_generation)
                return filename, get_audio_from_full_generation(full_generation)

            generate_voice_button.click(
                fn=generate_voice,
                inputs=[file_input, use_gpu_checkbox],
                # inputs=[file_input, use_gpu_checkbox, tokenizer_dropdown],
                outputs=[voice_file_name, audio_preview],
                preprocess=True,
                api_name="bark_voice_generate",
            )
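
# Hypothetical client-side call against the "bark_voice_generate" endpoint,
# assuming the standard gradio_client API and a locally running server:
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   filename, preview = client.predict(
#       handle_file("speaker.wav"), True, api_name="/bark_voice_generate"
#   )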


if __name__ == "__main__":
    with gr.Blocks() as demo:
        tab_voice_clone()

    demo.launch()