
Use .to(device) in place of .cuda() to support CPU and specific GPUs

Luke Van Seters, 1 year ago · commit 1d917e6d4c
2 changed files with 43 additions and 27 deletions:
  1. generators/utils.py (+8, -6)
  2. inference_for_demo_video.py (+35, -21)
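
The commit message sums up the pattern: .cuda() pins modules and tensors to the current CUDA device, while .to(device) accepts any torch.device, so a single code path can serve the CPU, the default GPU, or a specific GPU. A minimal sketch of the pattern, with illustrative names that are not from this repository:

    import torch

    # Resolve the target device once, then thread it through the code.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = torch.nn.Linear(4, 2).to(device)  # was: model.cuda()
    batch = torch.randn(8, 4).to(device)      # was: batch.cuda()
    output = model(batch).cpu()               # move results back to CPU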

generators/utils.py (+8, -6)

@@ -1,8 +1,8 @@
 import argparse
-import cv2
 import json
 import os
 
+import cv2
 import numpy as np
 import torch
 import torchvision
@@ -17,14 +17,15 @@ def obtain_seq_index(index, num_frames, radius):
 
 
 @torch.no_grad()
-def get_netG(checkpoint_path):
-    from generators.face_model import FaceGenerator
+def get_netG(checkpoint_path, device):
     import yaml
 
+    from generators.face_model import FaceGenerator
+
     with open("generators/renderer_conf.yaml", "r") as f:
         renderer_config = yaml.load(f, Loader=yaml.FullLoader)
 
-    renderer = FaceGenerator(**renderer_config).to(torch.cuda.current_device())
+    renderer = FaceGenerator(**renderer_config).to(device)
 
     checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
     renderer.load_state_dict(checkpoint["net_G_ema"], strict=False)
@@ -41,6 +42,7 @@ def render_video(
     exp_path,
     wav_path,
     output_path,
+    device,
     silent=False,
     semantic_radius=13,
     fps=30,
@@ -95,8 +97,8 @@ def render_video(
     target_splited_exps = torch.split(target_exp_concat, split_size, dim=0)
     output_imgs = []
     for win_exp in target_splited_exps:
-        win_exp = win_exp.cuda()
-        cur_src_img = src_img.expand(win_exp.shape[0], -1, -1, -1).cuda()
+        win_exp = win_exp.to(device)
+        cur_src_img = src_img.expand(win_exp.shape[0], -1, -1, -1).to(device)
         output_dict = net_G(cur_src_img, win_exp)
         output_imgs.append(output_dict["fake_image"].cpu().clamp_(-1, 1))
 

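Checkpoint loading is the other half of CPU support. The context line above keeps the existing map_location=lambda storage, loc: storage in generators/utils.py, which deserializes tensors onto CPU storage regardless of where they were saved; the second file below switches to map_location=device, which torch.load also accepts. A small self-contained sketch of both forms (the /tmp path is illustrative):

    import torch

    device = torch.device("cpu")  # or e.g. torch.device("cuda:0")

    # Save a toy checkpoint so the example is self-contained.
    torch.save({"weights": torch.randn(2, 2)}, "/tmp/toy_ckpt.pt")

    # Form kept in generators/utils.py: always deserialize onto CPU storage.
    ckpt_cpu = torch.load(
        "/tmp/toy_ckpt.pt", map_location=lambda storage, loc: storage
    )

    # Form introduced in inference_for_demo_video.py: deserialize straight
    # onto the target device.
    ckpt_dev = torch.load("/tmp/toy_ckpt.pt", map_location=device)
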
inference_for_demo_video.py (+35, -21)

@@ -1,31 +1,30 @@
 import argparse
-import torch
 import json
 import os
-
-from scipy.io import loadmat
+import shutil
 import subprocess
 
 import numpy as np
+import torch
 import torchaudio
-import shutil
+from scipy.io import loadmat
+from transformers import Wav2Vec2Processor
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
 
+from configs.default import get_cfg_defaults
+from core.networks.diffusion_net import DiffusionNet
+from core.networks.diffusion_util import NoisePredictor, VarianceSchedule
 from core.utils import (
+    crop_src_image,
     get_pose_params,
     get_video_style_clip,
     get_wav2vec_audio_window,
-    crop_src_image,
 )
-from configs.default import get_cfg_defaults
 from generators.utils import get_netG, render_video
-from core.networks.diffusion_net import DiffusionNet
-from core.networks.diffusion_util import NoisePredictor, VarianceSchedule
-from transformers import Wav2Vec2Processor
-from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
 
 
 @torch.no_grad()
-def get_diff_net(cfg):
+def get_diff_net(cfg, device):
     diff_net = DiffusionNet(
         cfg=cfg,
         net=NoisePredictor(cfg),
@@ -36,7 +35,7 @@ def get_diff_net(cfg):
             mode=cfg.DIFFUSION.SCHEDULE.MODE,
         ),
     )
-    checkpoint = torch.load(cfg.INFERENCE.CHECKPOINT)
+    checkpoint = torch.load(cfg.INFERENCE.CHECKPOINT, map_location=device)
     model_state_dict = checkpoint["model_state_dict"]
     diff_net_dict = {
         k[9:]: v for k, v in model_state_dict.items() if k[:9] == "diff_net."
@@ -62,6 +61,7 @@ def inference_one_video(
     pose_path,
     output_path,
     diff_net,
+    device,
     max_audio_len=None,
     sample_method="ddim",
     ddim_num_step=10,
@@ -79,7 +79,7 @@ def inference_one_video(
         win_size=cfg.WIN_SIZE,
     )
 
-    audio_win = torch.tensor(audio_win_array).cuda()
+    audio_win = torch.tensor(audio_win_array).to(device)
     audio = audio_win.unsqueeze(0)
 
     # the second parameter is "" because of bad interface design...
@@ -87,9 +87,9 @@ def inference_one_video(
         style_clip_path, "", style_max_len=256, start_idx=0
     )
 
-    style_clip = style_clip_raw.unsqueeze(0).cuda()
+    style_clip = style_clip_raw.unsqueeze(0).to(device)
     style_pad_mask = (
-        style_pad_mask_raw.unsqueeze(0).cuda()
+        style_pad_mask_raw.unsqueeze(0).to(device)
         if style_pad_mask_raw is not None
         else None
     )
@@ -151,8 +151,19 @@ if __name__ == "__main__":
         type=str,
         default="test",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+    )
     args = parser.parse_args()
 
+    if args.device == "cuda" and not torch.cuda.is_available():
+        print("CUDA is not available, set --device=cpu to use CPU.")
+        exit(1)
+
+    device = torch.device(args.device)
+
     cfg = get_cfg_defaults()
     cfg.CF_GUIDANCE.SCALE = args.cfg_scale
     cfg.freeze()
@@ -169,10 +180,11 @@ if __name__ == "__main__":
     wav2vec_processor = Wav2Vec2Processor.from_pretrained(
         "jonatasgrosman/wav2vec2-large-xlsr-53-english"
     )
+
     wav2vec_model = (
         Wav2Vec2Model.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
         .eval()
-        .cuda()
+        .to(device)
     )
 
     speech_array, sampling_rate = torchaudio.load(wav_16k_path)
@@ -182,9 +194,9 @@ if __name__ == "__main__":
     )
 
     with torch.no_grad():
-        audio_embedding = wav2vec_model(inputs.input_values.cuda(), return_dict=False)[
-            0
-        ]
+        audio_embedding = wav2vec_model(
+            inputs.input_values.to(device), return_dict=False
+        )[0]
 
     audio_feat_path = os.path.join(tmp_dir, f"{args.output_name}_wav2vec.npy")
     np.save(audio_feat_path, audio_embedding[0].cpu().numpy())
@@ -198,7 +210,7 @@ if __name__ == "__main__":
 
     with torch.no_grad():
         # get diff model and load checkpoint
-        diff_net = get_diff_net(cfg).cuda()
+        diff_net = get_diff_net(cfg, device).to(device)
         # generate face motion
         face_motion_path = os.path.join(tmp_dir, f"{args.output_name}_facemotion.npy")
         inference_one_video(
@@ -208,10 +220,11 @@ if __name__ == "__main__":
             args.pose_path,
             face_motion_path,
             diff_net,
+            device,
             max_audio_len=args.max_gen_len,
         )
         # get renderer
-        renderer = get_netG("checkpoints/renderer.pt")
+        renderer = get_netG("checkpoints/renderer.pt", device)
         # render video
         output_video_path = f"output_video/{args.output_name}.mp4"
         render_video(
@@ -220,6 +233,7 @@ if __name__ == "__main__":
             face_motion_path,
             wav_16k_path,
             output_video_path,
+            device,
             fps=25,
             no_move=False,
         )
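
Because args.device is handed straight to torch.device, any valid device string works, which is how the "specific GPUs" case from the commit message is reached. A quick illustration (constructing a device object is cheap; actually moving tensors to cuda:1 requires a second GPU):

    import torch

    torch.device("cpu")     # run everything on CPU
    torch.device("cuda")    # the current default CUDA device
    torch.device("cuda:1")  # GPU index 1, i.e. a specific GPU

So a CPU run would be invoked with --device cpu and a second-GPU run with --device cuda:1, with the other command-line arguments unchanged.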