@@ -0,0 +1,456 @@
+import argparse
+import json
+import logging
+import os
+import pickle
+import random
+from collections import defaultdict
+
+import cv2
+import dlib
+import numpy as np
+import torch
+from scipy.io import loadmat
+from torch import nn
+
+from configs.default import get_cfg_defaults
+
+
+def _reset_parameters(model):
+    # Xavier-initialize every weight matrix; 1-D parameters (biases, norm
+    # scales) keep their default initialization.
+    for p in model.parameters():
+        if p.dim() > 1:
+            nn.init.xavier_uniform_(p)
+
+
+def get_video_style(video_name, style_type):
+    """Derive a style key from a video name of the form
+    <person_id>_<direction>_<emotion>_<level>_*."""
+    person_id, direction, emotion, level, *_ = video_name.split("_")
+    if style_type == "id_dir_emo_level":
+        style = "_".join([person_id, direction, emotion, level])
+    elif style_type == "emotion":
+        style = emotion
+    elif style_type == "id":
+        style = person_id
+    else:
+        raise ValueError(f"Unknown style type {style_type}")
+
+    return style
+
+
+def get_style_video_lists(video_list, style_type):
+    style2video_list = defaultdict(list)
+    for video in video_list:
+        style = get_video_style(video, style_type)
+        style2video_list[style].append(video)
+
+    return style2video_list
+
+
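+# Hypothetical sketch of how the two helpers above compose, assuming
+# MEAD-style names "<id>_<direction>_<emotion>_...":
+#
+#   videos = ["M003_front_happy_level_1_001", "M003_front_sad_level_2_001"]
+#   get_video_style(videos[0], "emotion")      # -> "happy"
+#   get_style_video_lists(videos, "emotion")   # -> {"happy": [videos[0]], "sad": [videos[1]]}
+
+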
+def get_face3d_clip(
+    video_name, video_root_dir, num_frames, start_idx, dtype=torch.float32
+):
+    """Load a fixed-length clip of 3DMM expression coefficients for one video.
+
+    Args:
+        video_name (str): file name of the per-video 3DMM coefficients
+            (a .mat file with a "coeff" array, or a plain-text .txt matrix).
+        video_root_dir (str): directory containing the coefficient files.
+        num_frames (int): number of frames in the returned clip.
+        start_idx ("random" | "middle" | int): where the clip starts.
+        dtype (torch.dtype, optional): dtype of the returned tensor.
+            Defaults to torch.float32.
+
+    Raises:
+        ValueError: if the file extension or start_idx is not recognized.
+
+    Returns:
+        torch.Tensor: expression clip of shape (num_frames, C);
+            C = 64 for the .mat coefficient layout.
+    """
+    video_path = os.path.join(video_root_dir, video_name)
+    if video_path[-3:] == "mat":
+        face3d_all = loadmat(video_path)["coeff"]
+        face3d_exp = face3d_all[:, 80:144]  # expression 3DMM range
+    elif video_path[-3:] == "txt":
+        face3d_exp = np.loadtxt(video_path)
+    else:
+        raise ValueError("Invalid 3DMM file extension")
+
+    length = face3d_exp.shape[0]
+    clip_num_frames = num_frames
+    if start_idx == "random":
+        clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
+    elif start_idx == "middle":
+        clip_start_idx = (length - clip_num_frames + 1) // 2
+    elif isinstance(start_idx, int):
+        clip_start_idx = start_idx
+    else:
+        raise ValueError(f"Invalid start_idx {start_idx}")
+
+    face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
+    face3d_clip = torch.tensor(face3d_clip, dtype=dtype)
+
+    return face3d_clip
+
+
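+# Hypothetical usage sketch (file name and directory are illustrative only):
+#
+#   clip = get_face3d_clip("M003_front_happy_level_1_001.mat", "data/coeffs",
+#                          num_frames=64, start_idx="middle")
+#   clip.shape  # torch.Size([64, 64]): 64 frames x 64 expression coefficients
+
+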
+def get_video_style_clip(
+    video_name,
+    video_root_dir,
+    style_max_len,
+    start_idx="random",
+    dtype=torch.float32,
+    return_start_idx=False,
+):
+    """Load a style clip of at most style_max_len expression frames.
+
+    Videos longer than style_max_len are cropped (per start_idx); shorter
+    videos are zero-padded at the end, with the padded positions marked
+    True in the returned pad_mask.
+    """
+    video_path = os.path.join(video_root_dir, video_name)
+    if video_path[-3:] == "mat":
+        face3d_all = loadmat(video_path)["coeff"]
+        face3d_exp = face3d_all[:, 80:144]  # expression 3DMM range
+    elif video_path[-3:] == "txt":
+        face3d_exp = np.loadtxt(video_path)
+    else:
+        raise ValueError("Invalid 3DMM file extension")
+
+    face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
+
+    length = face3d_exp.shape[0]
+    if length >= style_max_len:
+        clip_num_frames = style_max_len
+        if start_idx == "random":
+            clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
+        elif start_idx == "middle":
+            clip_start_idx = (length - clip_num_frames + 1) // 2
+        elif isinstance(start_idx, int):
+            clip_start_idx = start_idx
+        else:
+            raise ValueError(f"Invalid start_idx {start_idx}")
+
+        face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
+        pad_mask = torch.tensor([False] * style_max_len)
+    else:
+        clip_start_idx = None
+        padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
+        face3d_clip = torch.cat((face3d_exp, padding), dim=0)
+        pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))
+
+    if return_start_idx:
+        return face3d_clip, pad_mask, clip_start_idx
+    else:
+        return face3d_clip, pad_mask
+
+
+def get_video_style_clip_from_np(
+    face3d_exp,
+    style_max_len,
+    start_idx="random",
+    dtype=torch.float32,
+    return_start_idx=False,
+):
+    """Same as get_video_style_clip, but takes an in-memory (L, C) expression
+    array instead of a path to a coefficient file."""
+    face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
+
+    length = face3d_exp.shape[0]
+    if length >= style_max_len:
+        clip_num_frames = style_max_len
+        if start_idx == "random":
+            clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
+        elif start_idx == "middle":
+            clip_start_idx = (length - clip_num_frames + 1) // 2
+        elif isinstance(start_idx, int):
+            clip_start_idx = start_idx
+        else:
+            raise ValueError(f"Invalid start_idx {start_idx}")
+
+        face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
+        pad_mask = torch.tensor([False] * style_max_len)
+    else:
+        clip_start_idx = None
+        padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
+        face3d_clip = torch.cat((face3d_exp, padding), dim=0)
+        pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))
+
+    if return_start_idx:
+        return face3d_clip, pad_mask, clip_start_idx
+    else:
+        return face3d_clip, pad_mask
+
+
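+# Minimal sketch of the padding behavior (shapes only; the data is random):
+#
+#   exp = np.random.randn(100, 64).astype(np.float32)   # 100-frame video
+#   clip, pad_mask = get_video_style_clip_from_np(exp, style_max_len=256)
+#   clip.shape        # torch.Size([256, 64]) -- zero-padded from frame 100 on
+#   pad_mask[:100]    # all False (real frames); pad_mask[100:] is all True
+
+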
+def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
+    """Gather a sliding audio-feature window around each video frame.
+
+    The audio features are assumed to run at twice the video frame rate
+    (e.g. 50 Hz wav2vec features against 25 fps video), hence the 2x index
+    mapping below. Out-of-range positions are zero-padded.
+
+    Args:
+        audio_feat (np.ndarray): (N, 1024) audio feature sequence.
+        start_idx (int): index of the first video frame.
+        num_frames (int): number of video frames to cover.
+        win_size (int): half-window size, in audio frames.
+
+    Returns:
+        np.ndarray: (num_frames, 2 * win_size + 1, 1024) audio windows.
+    """
+    center_idx_list = [2 * idx for idx in range(start_idx, start_idx + num_frames)]
+    audio_window_list = []
+    padding = np.zeros(audio_feat.shape[1], dtype=np.float32)
+    for center_idx in center_idx_list:
+        cur_audio_window = []
+        for i in range(center_idx - win_size, center_idx + win_size + 1):
+            if i < 0 or i >= len(audio_feat):
+                cur_audio_window.append(padding)
+            else:
+                cur_audio_window.append(audio_feat[i])
+        cur_audio_win_array = np.stack(cur_audio_window, axis=0)
+        audio_window_list.append(cur_audio_win_array)
+
+    audio_window_array = np.stack(audio_window_list, axis=0)
+    return audio_window_array
+
+
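+# Shape-check sketch: with 2 video frames and win_size=2, each frame gets a
+# 5-frame audio window centered at audio index 2 * frame_idx:
+#
+#   feat = np.random.randn(50, 1024).astype(np.float32)
+#   win = get_wav2vec_audio_window(feat, start_idx=0, num_frames=2, win_size=2)
+#   win.shape  # (2, 5, 1024); the first window is zero-padded at indices -2, -1
+
+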
+def setup_config():
+    parser = argparse.ArgumentParser(description="voice2pose main program")
+    parser.add_argument(
+        "--config_file", default="", metavar="FILE", help="path to config file"
+    )
+    parser.add_argument(
+        "--resume_from", type=str, default=None, help="the checkpoint to resume from"
+    )
+    parser.add_argument(
+        "--test_only", action="store_true", help="perform testing and evaluation only"
+    )
+    parser.add_argument(
+        "--demo_input", type=str, default=None, help="path to input for demo"
+    )
+    parser.add_argument(
+        "--checkpoint", type=str, default=None, help="the checkpoint to test with"
+    )
+    parser.add_argument("--tag", type=str, default="", help="tag for the experiment")
+    parser.add_argument(
+        "opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        help="local rank for DistributedDataParallel",
+    )
+    parser.add_argument(
+        "--master_port",
+        type=str,
+        default="12345",
+    )
+    parser.add_argument(
+        "--max_audio_len",
+        type=int,
+        default=450,
+        help="max_audio_len for inference",
+    )
+    parser.add_argument(
+        "--ddim_num_step",
+        type=int,
+        default=10,
+    )
+    parser.add_argument(
+        "--inference_seed",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--inference_sample_method",
+        type=str,
+        default="ddim",
+    )
+    args = parser.parse_args()
+
+    cfg = get_cfg_defaults()
+    cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+    cfg.freeze()
+    return args, cfg
+
+
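+# Hypothetical invocation (script name and paths are illustrative):
+#
+#   python main.py --config_file configs/default.yaml --test_only \
+#       --checkpoint checkpoints/model.pth TRAIN.BATCH_SIZE 4
+#
+# Trailing KEY VALUE pairs are collected in args.opts and merged into the
+# yacs config via cfg.merge_from_list.
+
+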
+def setup_logger(base_path, exp_name):
+    rootLogger = logging.getLogger()
+    rootLogger.setLevel(logging.INFO)
+
+    logFormatter = logging.Formatter("%(asctime)s [%(levelname)-0.5s] %(message)s")
+
+    log_path = "{0}/{1}.log".format(base_path, exp_name)
+    fileHandler = logging.FileHandler(log_path)
+    fileHandler.setFormatter(logFormatter)
+    rootLogger.addHandler(fileHandler)
+
+    consoleHandler = logging.StreamHandler()
+    consoleHandler.setFormatter(logFormatter)
+    rootLogger.addHandler(consoleHandler)
+    # Assumes the file handler added above is handlers[0]; the console handler
+    # inherits the root logger's INFO level.
+    rootLogger.handlers[0].setLevel(logging.INFO)
+
+    logging.info("log path: %s" % log_path)
+
+
+def cosine_loss(a, v, y, logloss=nn.BCELoss()):
+    # Binary cross-entropy on the cosine similarity between the audio and
+    # visual embeddings; y holds the 0/1 sync labels.
+    d = nn.functional.cosine_similarity(a, v)
+    loss = logloss(d.unsqueeze(1), y)
+    return loss
+
+
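+# Note: BCELoss requires its input in [0, 1], so the embeddings are expected
+# to be non-negative (e.g. ReLU outputs), which keeps the similarity >= 0:
+#
+#   a, v = torch.rand(8, 512), torch.rand(8, 512)   # non-negative embeddings
+#   y = torch.ones(8, 1)                            # "in sync" labels
+#   cosine_loss(a, v, y)                            # scalar tensor
+
+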
+def get_pose_params(mat_path):
+    """Get pose parameters from a 3DMM .mat file.
+
+    Args:
+        mat_path (str): path of the .mat file.
+
+    Returns:
+        pose_params (np.ndarray): shape (L_video, 9) -- rotation angles,
+            translation, and crop parameters, concatenated per frame.
+    """
+    mat_dict = loadmat(mat_path)
+
+    np_3dmm = mat_dict["coeff"]
+    angles = np_3dmm[:, 224:227]
+    translations = np_3dmm[:, 254:257]
+
+    np_trans_params = mat_dict["transform_params"]
+    crop = np_trans_params[:, -3:]
+
+    pose_params = np.concatenate((angles, translations, crop), axis=1)
+
+    return pose_params
+
+
+def sinusoidal_embedding(timesteps, dim):
+    """Standard transformer-style sinusoidal timestep embedding.
+
+    Args:
+        timesteps (torch.Tensor): (B,) diffusion timesteps.
+        dim (int): embedding dimension C_embed.
+
+    Returns:
+        torch.Tensor: (B, C_embed) embedding, zero-padded if dim is odd.
+    """
+    half = dim // 2
+    timesteps = timesteps.float()
+
+    # Outer product of the timesteps with a geometric frequency ladder
+    # 10000^(-k / half), k = 0..half-1.
+    sinusoid = torch.outer(
+        timesteps, torch.pow(10000, -torch.arange(half).to(timesteps).div(half))
+    )
+    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+    if dim % 2 != 0:
+        x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
+    return x
+
+
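+# Quick shape/value sketch:
+#
+#   t = torch.tensor([0, 10, 100])
+#   emb = sinusoidal_embedding(t, 128)
+#   emb.shape      # torch.Size([3, 128]): 64 cosine dims, then 64 sine dims
+#   emb[0, :64]    # all ones -- cos(0) -- and emb[0, 64:] is all zeros
+
+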
+def reshape_audio_feat(style_audio_all_raw, stride):
+    """Fold every `stride` consecutive audio frames into one feature vector.
+
+    Args:
+        style_audio_all_raw (np.ndarray): (stride * L, C); any trailing frames
+            that do not fill a full stride are dropped.
+        stride (int): number of frames folded into each output row.
+
+    Returns:
+        np.ndarray: (L, C * stride)
+    """
+    style_audio_all_raw = style_audio_all_raw[
+        : style_audio_all_raw.shape[0] // stride * stride
+    ]
+    style_audio_all_raw = style_audio_all_raw.reshape(
+        style_audio_all_raw.shape[0] // stride, stride, style_audio_all_raw.shape[1]
+    )
+    style_audio_all = style_audio_all_raw.reshape(style_audio_all_raw.shape[0], -1)
+    return style_audio_all
+
+
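+# e.g. a (10, 4) array with stride=2 becomes (5, 8): rows 0-1 are concatenated
+# into output row 0, rows 2-3 into row 1, and so on:
+#
+#   feat = np.arange(40).reshape(10, 4)
+#   reshape_audio_feat(feat, 2).shape   # (5, 8)
+
+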
+def get_derangement_tuple(n):
+    """Sample a random derangement of range(n), i.e. a permutation with no
+    fixed point, via backward Fisher-Yates shuffling with early restart."""
+    while True:
+        v = [i for i in range(n)]
+        for j in range(n - 1, -1, -1):
+            p = random.randint(0, j)
+            if v[p] == j:
+                break  # would create a fixed point; restart the shuffle
+            else:
+                v[j], v[p] = v[p], v[j]
+        else:
+            if v[0] != 0:
+                return tuple(v)
+
+
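+# For n = 3 the only derangements are the two 3-cycles, so every call returns
+# either (1, 2, 0) or (2, 0, 1) -- no element stays at its own index:
+#
+#   perm = get_derangement_tuple(3)
+#   all(perm[i] != i for i in range(3))   # True
+
+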
+def compute_aspect_preserved_bbox(bbox, increase_area, h, w):
+    """Expand a (left, top, right, bottom) box to a roughly square one with
+    side about (1 + 2 * increase_area) * max(width, height), shrinking it back
+    symmetrically if it would overflow the h x w image."""
+    left, top, right, bot = bbox
+    width = right - left
+    height = bot - top
+
+    width_increase = max(
+        increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)
+    )
+    height_increase = max(
+        increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)
+    )
+
+    left_t = int(left - width_increase * width)
+    top_t = int(top - height_increase * height)
+    right_t = int(right + width_increase * width)
+    bot_t = int(bot + height_increase * height)
+
+    # Estimate the out-of-bounds overshoot on each side of the expanded box.
+    left_oob = -min(0, left_t)
+    right_oob = right - min(right_t, w)
+    top_oob = -min(0, top_t)
+    bot_oob = bot - min(bot_t, h)
+
+    if max(left_oob, right_oob, top_oob, bot_oob) > 0:
+        max_w = max(left_oob, right_oob)
+        max_h = max(top_oob, bot_oob)
+        if max_w > max_h:
+            return left_t + max_w, top_t + max_w, right_t - max_w, bot_t - max_w
+        else:
+            return left_t + max_h, top_t + max_h, right_t - max_h, bot_t - max_h
+
+    else:
+        return (left_t, top_t, right_t, bot_t)
+
+
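+# Worked example: a 100x80 box in a 512x512 image with increase_area=0.25
+# becomes a 150x150 square (1.5 * the longer side), still inside the image:
+#
+#   compute_aspect_preserved_bbox((100, 100, 200, 180), 0.25, 512, 512)
+#   # -> (75, 65, 225, 215)
+
+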
+def crop_src_image(src_img, save_img, increase_ratio, detector=None):
+    if detector is None:
+        detector = dlib.get_frontal_face_detector()
+
+    img = cv2.imread(src_img)
+    faces = detector(img, 0)
+    h, width, _ = img.shape
+    if len(faces) > 0:
+        bbox = [faces[0].left(), faces[0].top(), faces[0].right(), faces[0].bottom()]
+        # Shift the box up by 10% of its height (dlib's box tends to sit low
+        # on the face), then clamp it to the image before the
+        # aspect-preserving expansion.
+        l = bbox[3] - bbox[1]
+        bbox[1] = bbox[1] - l * 0.1
+        bbox[3] = bbox[3] - l * 0.1
+        bbox[1] = max(0, bbox[1])
+        bbox[3] = min(h, bbox[3])
+        bbox = compute_aspect_preserved_bbox(
+            tuple(bbox), increase_ratio, img.shape[0], img.shape[1]
+        )
+        img = img[bbox[1] : bbox[3], bbox[0] : bbox[2]]
+        img = cv2.resize(img, (256, 256))
+        cv2.imwrite(save_img, img)
+    else:
+        raise ValueError("No face detected in the input image")
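+
+
+# Hypothetical usage (paths and ratio are illustrative):
+#
+#   crop_src_image("portrait.jpg", "portrait_cropped.jpg", increase_ratio=0.4)
+#
+# writes a 256x256 crop centered on the first detected face.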