@@ -1,31 +1,30 @@
 import argparse
-import torch
 import json
 import os
-
-from scipy.io import loadmat
+import shutil
 import subprocess
 
 import numpy as np
+import torch
 import torchaudio
-import shutil
+from scipy.io import loadmat
+from transformers import Wav2Vec2Processor
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
 
+from configs.default import get_cfg_defaults
+from core.networks.diffusion_net import DiffusionNet
+from core.networks.diffusion_util import NoisePredictor, VarianceSchedule
 from core.utils import (
+    crop_src_image,
     get_pose_params,
     get_video_style_clip,
     get_wav2vec_audio_window,
-    crop_src_image,
 )
-from configs.default import get_cfg_defaults
 from generators.utils import get_netG, render_video
-from core.networks.diffusion_net import DiffusionNet
-from core.networks.diffusion_util import NoisePredictor, VarianceSchedule
-from transformers import Wav2Vec2Processor
-from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
 
 
 @torch.no_grad()
-def get_diff_net(cfg):
+def get_diff_net(cfg, device):
     diff_net = DiffusionNet(
         cfg=cfg,
         net=NoisePredictor(cfg),
@@ -36,7 +35,7 @@ def get_diff_net(cfg):
             mode=cfg.DIFFUSION.SCHEDULE.MODE,
         ),
     )
-    checkpoint = torch.load(cfg.INFERENCE.CHECKPOINT)
+    checkpoint = torch.load(cfg.INFERENCE.CHECKPOINT, map_location=device)
     model_state_dict = checkpoint["model_state_dict"]
     diff_net_dict = {
         k[9:]: v for k, v in model_state_dict.items() if k[:9] == "diff_net."
@@ -62,6 +61,7 @@ def inference_one_video(
     pose_path,
     output_path,
     diff_net,
+    device,
     max_audio_len=None,
     sample_method="ddim",
     ddim_num_step=10,
@@ -79,7 +79,7 @@ def inference_one_video(
         win_size=cfg.WIN_SIZE,
     )
 
-    audio_win = torch.tensor(audio_win_array).cuda()
+    audio_win = torch.tensor(audio_win_array).to(device)
    audio = audio_win.unsqueeze(0)
 
     # the second parameter is "" because of bad interface design...
@@ -87,9 +87,9 @@ def inference_one_video(
         style_clip_path, "", style_max_len=256, start_idx=0
     )
 
-    style_clip = style_clip_raw.unsqueeze(0).cuda()
+    style_clip = style_clip_raw.unsqueeze(0).to(device)
     style_pad_mask = (
-        style_pad_mask_raw.unsqueeze(0).cuda()
+        style_pad_mask_raw.unsqueeze(0).to(device)
         if style_pad_mask_raw is not None
         else None
     )
@@ -151,8 +151,19 @@ if __name__ == "__main__":
         type=str,
         default="test",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+    )
     args = parser.parse_args()
 
+    if args.device == "cuda" and not torch.cuda.is_available():
+        print("CUDA is not available, set --device=cpu to use CPU.")
+        exit(1)
+
+    device = torch.device(args.device)
+
     cfg = get_cfg_defaults()
     cfg.CF_GUIDANCE.SCALE = args.cfg_scale
     cfg.freeze()
@@ -169,10 +180,11 @@ if __name__ == "__main__":
     wav2vec_processor = Wav2Vec2Processor.from_pretrained(
         "jonatasgrosman/wav2vec2-large-xlsr-53-english"
    )
+
     wav2vec_model = (
         Wav2Vec2Model.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
         .eval()
-        .cuda()
+        .to(device)
     )
 
     speech_array, sampling_rate = torchaudio.load(wav_16k_path)
@@ -182,9 +194,9 @@ if __name__ == "__main__":
     )
 
     with torch.no_grad():
-        audio_embedding = wav2vec_model(inputs.input_values.cuda(), return_dict=False)[
-            0
-        ]
+        audio_embedding = wav2vec_model(
+            inputs.input_values.to(device), return_dict=False
+        )[0]
 
     audio_feat_path = os.path.join(tmp_dir, f"{args.output_name}_wav2vec.npy")
     np.save(audio_feat_path, audio_embedding[0].cpu().numpy())
@@ -198,7 +210,7 @@ if __name__ == "__main__":
 
     with torch.no_grad():
         # get diff model and load checkpoint
-        diff_net = get_diff_net(cfg).cuda()
+        diff_net = get_diff_net(cfg, device).to(device)
         # generate face motion
         face_motion_path = os.path.join(tmp_dir, f"{args.output_name}_facemotion.npy")
         inference_one_video(
@@ -208,10 +220,11 @@ if __name__ == "__main__":
             args.pose_path,
             face_motion_path,
             diff_net,
+            device,
             max_audio_len=args.max_gen_len,
         )
         # get renderer
-        renderer = get_netG("checkpoints/renderer.pt")
+        renderer = get_netG("checkpoints/renderer.pt", device)
         # render video
         output_video_path = f"output_video/{args.output_name}.mp4"
         render_video(
@@ -220,6 +233,7 @@ if __name__ == "__main__":
             face_motion_path,
             wav_16k_path,
             output_video_path,
+            device,
             fps=25,
             no_move=False,
         )
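
All of the hunks above serve one change: device selection moves from hard-coded .cuda() calls to a single device object, created once from a new --device flag and threaded through every model, tensor, and helper. The pattern is sketched in isolation below; the helper name resolve_device is illustrative and does not appear in the patch itself.

    import argparse

    import torch

    def resolve_device(requested: str) -> torch.device:
        # Fail fast if CUDA was requested but is unavailable, instead of
        # crashing later inside the first .to(device) call.
        if requested == "cuda" and not torch.cuda.is_available():
            raise SystemExit("CUDA is not available, set --device=cpu to use CPU.")
        return torch.device(requested)

    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()
    device = resolve_device(args.device)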
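
The map_location=device argument added to torch.load is what makes CPU inference possible at all: a checkpoint saved from a CUDA run stores its tensors as CUDA tensors, and unpickling them on a machine without a GPU raises a deserialization error unless they are remapped. A minimal illustration, with a placeholder checkpoint path:

    import torch

    # Remap GPU-saved tensors onto the CPU while loading; without map_location
    # this raises "Attempting to deserialize object on a CUDA device" on
    # CPU-only hosts.
    checkpoint = torch.load("checkpoints/example.pt", map_location=torch.device("cpu"))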
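
The wav2vec2 hunks follow the companion rule: once the model has been moved with .to(device), every input tensor must be moved the same way, because PyTorch raises a device-mismatch error when a module and its inputs disagree. A self-contained sketch of that step, using one second of silence in place of real speech:

    import numpy as np
    import torch
    from transformers import Wav2Vec2Processor
    from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model

    device = torch.device("cpu")  # or torch.device("cuda")
    model_id = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
    processor = Wav2Vec2Processor.from_pretrained(model_id)
    model = Wav2Vec2Model.from_pretrained(model_id).eval().to(device)

    # The processor expects 16 kHz mono audio.
    speech = np.zeros(16000, dtype=np.float32)
    inputs = processor(speech, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        # Move the inputs to the model's device before the forward pass.
        audio_embedding = model(inputs.input_values.to(device), return_dict=False)[0]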