cnhubert.py

import logging

import torch
import torch.nn as nn
from transformers import (
    HubertModel,
    Wav2Vec2FeatureExtractor,
)

import utils

logging.getLogger("numba").setLevel(logging.WARNING)

# Path to the chinese-hubert-base checkpoint. The caller is expected to set
# this before instantiating CNHubert / calling get_model().
cnhubert_base_path = None


class CNHubert(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = HubertModel.from_pretrained(cnhubert_base_path)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            cnhubert_base_path
        )

    def forward(self, x):
        # Normalize the raw 16 kHz waveform, then move it onto the same
        # device as the input before running the HuBERT encoder.
        input_values = self.feature_extractor(
            x, return_tensors="pt", sampling_rate=16000
        ).input_values.to(x.device)
        feats = self.model(input_values)["last_hidden_state"]
        return feats
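
# For a HuBERT-base checkpoint such as chinese-hubert-base, forward() should
# return roughly one 768-dim frame per 20 ms of audio, i.e. a tensor of shape
# [batch, n_frames, 768]. The hidden size is a property of the checkpoint,
# so treat 768 as an assumption rather than a guarantee.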


# Alternative content encoders kept for reference; the hard-coded paths are
# from the original author's environment. Note that cnw2v2base would also
# need `from transformers import Wav2Vec2Model` if re-enabled.
#
# class CNHubertLarge(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.model = HubertModel.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
#         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
#
#     def forward(self, x):
#         input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
#         feats = self.model(input_values)["last_hidden_state"]
#         return feats
#
#
# class CVec(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.model = HubertModel.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
#         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
#
#     def forward(self, x):
#         input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
#         feats = self.model(input_values)["last_hidden_state"]
#         return feats
#
#
# class cnw2v2base(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.model = Wav2Vec2Model.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
#         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
#
#     def forward(self, x):
#         input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
#         feats = self.model(input_values)["last_hidden_state"]
#         return feats


def get_model():
    model = CNHubert()
    model.eval()
    return model
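
# Minimal usage sketch. The checkpoint path below is an assumption; point it
# at wherever your chinese-hubert-base download actually lives:
#
#   import cnhubert
#   cnhubert.cnhubert_base_path = "pretrained_models/chinese-hubert-base"
#   hmodel = cnhubert.get_model()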


# def get_large_model():
#     model = CNHubertLarge()
#     model.eval()
#     return model
#
# def get_model_cvec():
#     model = CVec()
#     model.eval()
#     return model
#
# def get_model_cnw2v2base():
#     model = cnw2v2base()
#     model.eval()
#     return model


def get_content(hmodel, wav_16k_tensor):
    # Run feature extraction without gradient tracking, then swap the time
    # and feature axes: [batch, frames, hidden] -> [batch, hidden, frames].
    with torch.no_grad():
        feats = hmodel(wav_16k_tensor)
    return feats.transpose(1, 2)
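
# Quick shape check with a synthetic waveform (a sketch; assumes
# cnhubert_base_path has been set so that get_model() can load weights):
#
#   wav = torch.randn(16000)              # one second of audio at 16 kHz
#   feats = get_content(get_model(), wav)
#   print(feats.shape)                    # e.g. torch.Size([1, 768, 49])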


if __name__ == "__main__":
    # Ad-hoc smoke test against a hard-coded local file; cnhubert_base_path
    # must point at a valid checkpoint for get_model() to succeed.
    model = get_model()
    src_path = "/Users/Shared/原音频2.wav"
    wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000)
    feats = get_content(model, wav_16k_tensor)
    print(feats.shape)