123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456 |
- import os
- import argparse
- from collections import defaultdict
- import logging
- import pickle
- import json
- import numpy as np
- import torch
- from torch import nn
- from scipy.io import loadmat
- from configs.default import get_cfg_defaults
- import dlib
- import cv2
def _reset_parameters(model):
    """Re-initialise the multi-dimensional parameters of ``model`` in place.

    Every parameter whose ``dim()`` is greater than one is passed to
    ``nn.init.xavier_uniform_``; parameters with one dimension or fewer
    (e.g. bias vectors) are left untouched.

    Args:
        model: Object exposing ``parameters()``, such as a ``torch.nn.Module``.
    """
    multi_dim_params = (param for param in model.parameters() if param.dim() > 1)
    for param in multi_dim_params:
        nn.init.xavier_uniform_(param)
def get_video_style(video_name, style_type):
    """Derive a style label from an underscore-delimited video name.

    The first four ``_``-separated fields of ``video_name`` are read as
    person id, direction, emotion and level; any further fields are ignored.

    Args:
        video_name (str): Name with at least four ``_``-separated fields.
        style_type (str): One of ``"id_dir_emo_level"``, ``"emotion"`` or
            ``"id"``.

    Returns:
        str: The four fields re-joined with ``_``, the emotion field, or the
        person id field, depending on ``style_type``.

    Raises:
        ValueError: If ``video_name`` has fewer than four fields, or if
            ``style_type`` is not one of the supported values.
    """
    # Unpack first so that malformed names fail before style_type is checked.
    person_id, direction, emotion, level, *_ = video_name.split("_")

    if style_type == "id":
        return person_id
    if style_type == "emotion":
        return emotion
    if style_type == "id_dir_emo_level":
        return "_".join((person_id, direction, emotion, level))
    raise ValueError("Unknown style type")
def get_style_video_lists(video_list, style_type):
    """Group video names by their style label.

    Args:
        video_list: Iterable of video names accepted by ``get_video_style``.
        style_type (str): Style type forwarded to ``get_video_style``.

    Returns:
        collections.defaultdict: Maps each style label to the list of video
        names carrying that label, in the order they were encountered.
    """
    videos_by_style = defaultdict(list)
    for video_name in video_list:
        videos_by_style[get_video_style(video_name, style_type)].append(video_name)
    return videos_by_style
def get_face3d_clip(
    video_name, video_root_dir, num_frames, start_idx, dtype=torch.float32
):
    """Load a window of expression coefficients for one video.

    The coefficient file is ``os.path.join(video_root_dir, video_name)``:

    * a path ending in ``"mat"`` is read with ``scipy.io.loadmat`` and columns
      ``80:144`` of its ``"coeff"`` entry are used;
    * a path ending in ``"txt"`` is read with ``np.loadtxt`` and used as is.

    Args:
        video_name (str): File name of the coefficient file.
        video_root_dir (str): Directory containing the file.
        num_frames (int): Number of consecutive frames to slice out.
        start_idx: ``"random"`` (uniformly sampled start such that the window
            fits), ``"middle"`` (window centred in the sequence) or an ``int``
            giving the first frame directly.
        dtype (torch.dtype, optional): dtype of the returned tensor.
            Defaults to ``torch.float32``.

    Returns:
        torch.Tensor: ``face3d_exp[start : start + num_frames]`` converted to
        ``dtype``. No length check is performed, so an ``int`` or ``"middle"``
        start on a sequence shorter than ``num_frames`` yields a shorter clip.

    Raises:
        ValueError: If the file extension is neither ``mat`` nor ``txt``, if
            ``start_idx`` is not one of the accepted forms, or (raised by
            ``np.random.randint``) if ``"random"`` is requested for a sequence
            shorter than ``num_frames``.
    """
    video_path = os.path.join(video_root_dir, video_name)
    if video_path.endswith("mat"):
        face3d_all = loadmat(video_path)["coeff"]
        face3d_exp = face3d_all[:, 80:144]  # expression 3DMM range
    elif video_path.endswith("txt"):
        # ndmin=2 keeps a single-frame file as (1, C); without it loadtxt
        # squeezes the row to (C,) and coefficients get treated as frames.
        face3d_exp = np.loadtxt(video_path, ndmin=2)
    else:
        raise ValueError("Invalid 3DMM file extension")

    length = face3d_exp.shape[0]
    if start_idx == "random":
        clip_start_idx = np.random.randint(low=0, high=length - num_frames + 1)
    elif start_idx == "middle":
        clip_start_idx = (length - num_frames + 1) // 2
    elif isinstance(start_idx, int):
        clip_start_idx = start_idx
    else:
        raise ValueError(f"Invalid start_idx {start_idx}")

    face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + num_frames]
    return torch.tensor(face3d_clip, dtype=dtype)
def get_video_style_clip(
    video_name,
    video_root_dir,
    style_max_len,
    start_idx="random",
    dtype=torch.float32,
    return_start_idx=False,
):
    """Load a style clip of exactly ``style_max_len`` frames plus its pad mask.

    The coefficient file is ``os.path.join(video_root_dir, video_name)``; a
    path ending in ``"mat"`` contributes columns ``80:144`` of its ``"coeff"``
    entry, a path ending in ``"txt"`` is read with ``np.loadtxt``.

    If the sequence has at least ``style_max_len`` frames, a window of that
    length is sliced out according to ``start_idx``. Otherwise the whole
    sequence is kept and zero rows are appended; ``start_idx`` is not
    consulted in that case.

    Args:
        video_name (str): File name of the coefficient file.
        video_root_dir (str): Directory containing the file.
        style_max_len (int): Number of frames in the returned clip.
        start_idx: ``"random"``, ``"middle"`` or an ``int`` first-frame index.
        dtype (torch.dtype, optional): dtype of the returned clip, including
            its padding rows. Defaults to ``torch.float32``.
        return_start_idx (bool, optional): Also return the chosen start index.

    Returns:
        tuple: ``(face3d_clip, pad_mask)`` or, when ``return_start_idx`` is
        true, ``(face3d_clip, pad_mask, clip_start_idx)``. ``pad_mask`` is a
        boolean tensor of length ``style_max_len`` that is ``True`` on padded
        rows. ``clip_start_idx`` is ``None`` when padding was applied.

    Raises:
        ValueError: If the file extension is neither ``mat`` nor ``txt``, or if
            ``start_idx`` is not an accepted form when a window must be cut.
    """
    video_path = os.path.join(video_root_dir, video_name)
    if video_path.endswith("mat"):
        face3d_all = loadmat(video_path)["coeff"]
        face3d_exp = face3d_all[:, 80:144]  # expression 3DMM range
    elif video_path.endswith("txt"):
        # ndmin=2 keeps a single-frame file as (1, C) instead of a flat (C,).
        face3d_exp = np.loadtxt(video_path, ndmin=2)
    else:
        raise ValueError("Invalid 3DMM file extension")

    face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
    length = face3d_exp.shape[0]

    if length >= style_max_len:
        if start_idx == "random":
            clip_start_idx = np.random.randint(low=0, high=length - style_max_len + 1)
        elif start_idx == "middle":
            clip_start_idx = (length - style_max_len + 1) // 2
        elif isinstance(start_idx, int):
            clip_start_idx = start_idx
        else:
            raise ValueError(f"Invalid start_idx {start_idx}")

        face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + style_max_len]
        pad_mask = torch.zeros(style_max_len, dtype=torch.bool)
    else:
        clip_start_idx = None
        # The padding must share the clip's dtype; otherwise torch.cat either
        # promotes the result away from the requested dtype or rejects the mix.
        padding = torch.zeros(
            style_max_len - length, face3d_exp.shape[1], dtype=dtype
        )
        face3d_clip = torch.cat((face3d_exp, padding), dim=0)
        pad_mask = torch.arange(style_max_len) >= length

    if return_start_idx:
        return face3d_clip, pad_mask, clip_start_idx
    return face3d_clip, pad_mask
def get_video_style_clip_from_np(
    face3d_exp,
    style_max_len,
    start_idx="random",
    dtype=torch.float32,
    return_start_idx=False,
):
    """Cut or pad an in-memory coefficient sequence to ``style_max_len`` frames.

    If the sequence has at least ``style_max_len`` frames, a window of that
    length is sliced out according to ``start_idx``. Otherwise the whole
    sequence is kept and zero rows are appended; ``start_idx`` is not
    consulted in that case.

    Args:
        face3d_exp: Array-like accepted by ``torch.tensor``; the first axis is
            the frame axis and, when padding is needed, a second axis must
            exist.
        style_max_len (int): Number of frames in the returned clip.
        start_idx: ``"random"``, ``"middle"`` or an ``int`` first-frame index.
        dtype (torch.dtype, optional): dtype of the returned clip, including
            its padding rows. Defaults to ``torch.float32``.
        return_start_idx (bool, optional): Also return the chosen start index.

    Returns:
        tuple: ``(face3d_clip, pad_mask)`` or, when ``return_start_idx`` is
        true, ``(face3d_clip, pad_mask, clip_start_idx)``. ``pad_mask`` is a
        boolean tensor of length ``style_max_len`` that is ``True`` on padded
        rows. ``clip_start_idx`` is ``None`` when padding was applied.

    Raises:
        ValueError: If ``start_idx`` is not an accepted form when a window
            must be cut.
    """
    face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
    length = face3d_exp.shape[0]

    if length >= style_max_len:
        if start_idx == "random":
            clip_start_idx = np.random.randint(low=0, high=length - style_max_len + 1)
        elif start_idx == "middle":
            clip_start_idx = (length - style_max_len + 1) // 2
        elif isinstance(start_idx, int):
            clip_start_idx = start_idx
        else:
            raise ValueError(f"Invalid start_idx {start_idx}")

        face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + style_max_len]
        pad_mask = torch.zeros(style_max_len, dtype=torch.bool)
    else:
        clip_start_idx = None
        # The padding must share the clip's dtype; otherwise torch.cat either
        # promotes the result away from the requested dtype or rejects the mix.
        padding = torch.zeros(
            style_max_len - length, face3d_exp.shape[1], dtype=dtype
        )
        face3d_clip = torch.cat((face3d_exp, padding), dim=0)
        pad_mask = torch.arange(style_max_len) >= length

    if return_start_idx:
        return face3d_clip, pad_mask, clip_start_idx
    return face3d_clip, pad_mask
def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
    """Gather a zero-padded window of audio features around each frame.

    NOTE: a second definition with this same name appears later in this file
    and is the one bound to the name once the module has finished importing.

    For every frame index ``f`` in ``[start_idx, start_idx + num_frames)`` the
    window is centred on audio row ``2 * f`` and spans
    ``[2 * f - win_size, 2 * f + win_size]`` inclusive. Rows falling outside
    ``audio_feat`` are replaced by zeros.

    Args:
        audio_feat (np.ndarray): ``(N, C)`` feature matrix (the original notes
            mention ``C = 1024``).
        start_idx (int): First frame index.
        num_frames (int): Number of frames to build windows for.
        win_size (int): Number of audio rows taken on each side of the centre.

    Returns:
        np.ndarray: Array of shape ``(num_frames, 2 * win_size + 1, C)`` with
        the same dtype as ``audio_feat``.
    """
    audio_feat = np.asarray(audio_feat)
    num_rows = len(audio_feat)
    # Pad with the feature dtype so the result dtype does not depend on
    # whether any window happens to cross the sequence boundary.
    zero_row = np.zeros(audio_feat.shape[1], dtype=audio_feat.dtype)

    windows = []
    for frame_idx in range(start_idx, start_idx + num_frames):
        center = 2 * frame_idx
        rows = [
            audio_feat[i] if 0 <= i < num_rows else zero_row
            for i in range(center - win_size, center + win_size + 1)
        ]
        windows.append(np.stack(rows, axis=0))
    return np.stack(windows, axis=0)
def setup_config():
    """Parse the command line and build the frozen experiment config.

    The parser reads ``sys.argv``. The config object comes from
    ``get_cfg_defaults()`` and is then updated from ``--config_file`` and from
    the trailing ``opts`` list before being frozen. The config file path is
    merged unconditionally, including when it is left at its ``""`` default.

    Returns:
        tuple: ``(args, cfg)`` where ``args`` is the ``argparse.Namespace`` and
        ``cfg`` is the frozen config object.
    """
    # (name, add_argument keyword arguments); registration order is preserved.
    argument_specs = (
        ("--config_file", dict(default="", metavar="FILE", help="path to config file")),
        (
            "--resume_from",
            dict(type=str, default=None, help="the checkpoint to resume from"),
        ),
        (
            "--test_only",
            dict(action="store_true", help="perform testing and evaluation only"),
        ),
        ("--demo_input", dict(type=str, default=None, help="path to input for demo")),
        (
            "--checkpoint",
            dict(type=str, default=None, help="the checkpoint to test with"),
        ),
        ("--tag", dict(type=str, default="", help="tag for the experiment")),
        (
            "opts",
            dict(
                help="Modify config options using the command-line",
                default=None,
                nargs=argparse.REMAINDER,
            ),
        ),
        (
            "--local_rank",
            dict(type=int, help="local rank for DistributedDataParallel"),
        ),
        ("--master_port", dict(type=str, default="12345")),
        (
            "--max_audio_len",
            dict(type=int, default=450, help="max_audio_len for inference"),
        ),
        ("--ddim_num_step", dict(type=int, default=10)),
        ("--inference_seed", dict(type=int, default=1)),
        ("--inference_sample_method", dict(type=str, default="ddim")),
    )

    parser = argparse.ArgumentParser(description="voice2pose main program")
    for arg_name, arg_kwargs in argument_specs:
        parser.add_argument(arg_name, **arg_kwargs)
    args = parser.parse_args()

    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return args, cfg
def setup_logger(base_path, exp_name):
    """Attach a file handler and a console handler to the root logger.

    The root logger level is set to ``INFO``. Log records are written to
    ``<base_path>/<exp_name>.log`` and to the default ``StreamHandler`` stream,
    both using the same formatter. Finally the log path itself is logged.

    Args:
        base_path (str): Directory in which the log file is created.
        exp_name (str): Base name of the log file (``.log`` is appended).
    """
    log_path = "{0}/{1}.log".format(base_path, exp_name)
    formatter = logging.Formatter("%(asctime)s [%(levelname)-0.5s] %(message)s")

    root = logging.getLogger()
    root.setLevel(logging.INFO)
    for handler in (logging.FileHandler(log_path), logging.StreamHandler()):
        handler.setFormatter(formatter)
        root.addHandler(handler)

    # Applies to whichever handler is first on the root logger, which is not
    # necessarily one of the two added above if handlers were already present.
    root.handlers[0].setLevel(logging.INFO)
    logging.info("log path: %s" % log_path)
def cosine_loss(a, v, y, logloss=nn.BCELoss()):
    """Apply ``logloss`` to the cosine similarity between ``a`` and ``v``.

    Args:
        a (torch.Tensor): First batch of embeddings.
        v (torch.Tensor): Second batch of embeddings, compared against ``a``
            with ``nn.functional.cosine_similarity`` (default ``dim``).
        y (torch.Tensor): Targets passed to ``logloss``; must match the
            similarity after a new axis is inserted at position 1.
        logloss (callable, optional): Criterion called as
            ``logloss(similarity, y)``. Defaults to a shared ``nn.BCELoss``
            instance created when the function is defined.

    Returns:
        The value returned by ``logloss``.
    """
    similarity = nn.functional.cosine_similarity(a, v)
    return logloss(similarity.unsqueeze(dim=1), y)
def get_pose_params(mat_path):
    """Read pose parameters from a ``.mat`` file.

    Args:
        mat_path (str): Path of the mat file. It must contain a ``"coeff"``
            matrix and a ``"transform_params"`` matrix with the same number
            of rows.

    Returns:
        numpy.ndarray: Shape ``(L_video, 9)``. The columns are, in order,
        ``coeff[:, 224:227]`` (angles), ``coeff[:, 254:257]`` (translations)
        and the last three columns of ``transform_params`` (crop parameters).
    """
    contents = loadmat(mat_path)
    coeff = contents["coeff"]
    transform_params = contents["transform_params"]

    pose_parts = (
        coeff[:, 224:227],  # angles
        coeff[:, 254:257],  # translations
        transform_params[:, -3:],  # crop parameters
    )
    return np.hstack(pose_parts)
def sinusoidal_embedding(timesteps, dim):
    """Compute sinusoidal embeddings for a batch of timesteps.

    With ``half = dim // 2`` the frequencies are ``10000 ** (-k / half)`` for
    ``k`` in ``range(half)``. The output concatenates the cosines of
    ``timestep * frequency`` followed by the sines; for odd ``dim`` one zero
    column is appended so the width is always exactly ``dim``.

    Args:
        timesteps (torch.Tensor): 1-D tensor of shape ``(B,)``.
        dim (int): Embedding width ``C_embed``.

    Returns:
        torch.Tensor: Float tensor of shape ``(B, C_embed)``.
    """
    half = dim // 2
    timesteps = timesteps.float()

    # compute sinusoidal embedding
    exponents = -torch.arange(half).to(timesteps).div(half)
    sinusoid = torch.outer(timesteps, torch.pow(10000, exponents))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)

    if dim % 2 != 0:
        # Build the pad column explicitly: slicing x[:, :1] is empty when
        # half == 0 (dim == 1), which would leave the output one column short.
        x = torch.cat([x, x.new_zeros(x.shape[0], 1)], dim=1)
    return x
def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
    """Gather a zero-padded window of audio features around each frame.

    NOTE: this repeats an earlier definition of the same name in this file;
    being the later one, it is the definition in effect after import.

    For every frame index ``f`` in ``[start_idx, start_idx + num_frames)`` the
    window is centred on audio row ``2 * f`` and spans
    ``[2 * f - win_size, 2 * f + win_size]`` inclusive. Rows falling outside
    ``audio_feat`` are replaced by zeros.

    Args:
        audio_feat (np.ndarray): ``(N, C)`` feature matrix (the original notes
            mention ``(250, 1024)``).
        start_idx (int): First frame index.
        num_frames (int): Number of frames to build windows for.
        win_size (int): Number of audio rows taken on each side of the centre.

    Returns:
        np.ndarray: Array of shape ``(num_frames, 2 * win_size + 1, C)`` with
        the same dtype as ``audio_feat``.
    """
    audio_feat = np.asarray(audio_feat)
    num_rows = len(audio_feat)
    # Pad with the feature dtype so the result dtype does not depend on
    # whether any window happens to cross the sequence boundary.
    zero_row = np.zeros(audio_feat.shape[1], dtype=audio_feat.dtype)

    windows = []
    for frame_idx in range(start_idx, start_idx + num_frames):
        center = 2 * frame_idx
        rows = [
            audio_feat[i] if 0 <= i < num_rows else zero_row
            for i in range(center - win_size, center + win_size + 1)
        ]
        windows.append(np.stack(rows, axis=0))
    return np.stack(windows, axis=0)
def reshape_audio_feat(style_audio_all_raw, stride):
    """Fold every ``stride`` consecutive feature rows into one wider row.

    Trailing rows that do not fill a complete group of ``stride`` are dropped.

    Args:
        style_audio_all_raw: 2-D array of shape ``(stride * L + r, C)`` with
            ``0 <= r < stride``.
        stride (int): Number of consecutive rows merged into one output row.

    Returns:
        Array of shape ``(L, C * stride)``; output row ``i`` is input rows
        ``i * stride`` to ``(i + 1) * stride - 1`` laid end to end.
    """
    num_groups = style_audio_all_raw.shape[0] // stride
    trimmed = style_audio_all_raw[: num_groups * stride]
    grouped = trimmed.reshape(num_groups, stride, trimmed.shape[1])
    return grouped.reshape(num_groups, -1)
- import random
def get_derangement_tuple(n):
    """Return a random derangement of ``range(n)`` as a tuple.

    A derangement is a permutation with no fixed point, i.e. ``result[i] != i``
    for every ``i``. Candidates are produced by a backwards shuffle that is
    abandoned and restarted as soon as a swap would create a fixed point.

    Args:
        n (int): Number of elements.

    Returns:
        tuple: A permutation of ``0 .. n - 1`` without fixed points; the empty
        tuple for ``n == 0``.

    Raises:
        ValueError: If ``n`` is negative, or if ``n == 1`` (a single element
            can only map to itself, so the rejection loop would never end).
    """
    if n < 0 or n == 1:
        raise ValueError(f"No derangement exists for n={n}")
    if n == 0:
        return ()

    while True:
        v = list(range(n))
        for j in range(n - 1, -1, -1):
            p = random.randint(0, j)
            if v[p] == j:
                # Moving v[p] to slot j would fix it in place: start over.
                break
            v[j], v[p] = v[p], v[j]
        else:
            # The shuffle completed; slot 0 is the only one not yet verified.
            if v[0] != 0:
                return tuple(v)
def compute_aspect_preserved_bbox(bbox, increase_area, h, w):
    """Expand a box towards a square and keep the result inside the image.

    Each side is grown by at least ``increase_area`` times the corresponding
    box dimension; the shorter dimension is grown further so that the expanded
    box is approximately square. If the expanded box sticks out of the
    ``h`` x ``w`` image on any side, all four sides are pulled in by the
    largest overflow, which keeps the aspect ratio of the expanded box.

    Args:
        bbox: ``(left, top, right, bottom)`` of the original box.
        increase_area (float): Minimum relative growth per side.
        h (int): Image height.
        w (int): Image width.

    Returns:
        tuple: ``(left, top, right, bottom)`` of the adjusted box.
    """
    left, top, right, bot = bbox
    width = right - left
    height = bot - top

    width_increase = max(
        increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)
    )
    height_increase = max(
        increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)
    )

    left_t = int(left - width_increase * width)
    top_t = int(top - height_increase * height)
    right_t = int(right + width_increase * width)
    bot_t = int(bot + height_increase * height)

    # Overflow of the *expanded* box on each side. Measuring the right/bottom
    # overflow from the unexpanded ``right``/``bot`` (as before) could never
    # detect the expanded box running past the image on those sides.
    left_oob = -min(0, left_t)
    right_oob = right_t - min(right_t, w)
    top_oob = -min(0, top_t)
    bot_oob = bot_t - min(bot_t, h)

    shrink = max(left_oob, right_oob, top_oob, bot_oob)
    if shrink > 0:
        return left_t + shrink, top_t + shrink, right_t - shrink, bot_t - shrink
    return (left_t, top_t, right_t, bot_t)
def crop_src_image(src_img, save_img, increase_ratio, detector=None):
    """Crop the first detected face from an image and save it at 256x256.

    The detected face box is shifted up by 10% of its height (clamped to the
    image), expanded with ``compute_aspect_preserved_bbox``, cropped, resized
    to 256x256 and written to ``save_img``.

    Args:
        src_img (str): Path of the image to read with ``cv2.imread``.
        save_img (str): Path the cropped image is written to.
        increase_ratio (float): ``increase_area`` passed to
            ``compute_aspect_preserved_bbox``.
        detector (callable, optional): Face detector called as
            ``detector(img, 0)``. Defaults to
            ``dlib.get_frontal_face_detector()``.

    Raises:
        ValueError: If the image cannot be read, or if no face is detected.
    """
    if detector is None:
        detector = dlib.get_frontal_face_detector()

    img = cv2.imread(src_img)
    if img is None:
        # cv2.imread signals a missing or undecodable file by returning None.
        raise ValueError(f"Failed to read image: {src_img}")

    faces = detector(img, 0)
    if len(faces) == 0:
        raise ValueError("No face detected in the input image")

    img_h, img_w = img.shape[:2]
    face = faces[0]
    bbox = [face.left(), face.top(), face.right(), face.bottom()]

    # Shift the box upwards by a tenth of the face height, staying in bounds.
    face_height = bbox[3] - bbox[1]
    bbox[1] = max(0, bbox[1] - face_height * 0.1)
    bbox[3] = min(img_h, bbox[3] - face_height * 0.1)

    bbox = compute_aspect_preserved_bbox(tuple(bbox), increase_ratio, img_h, img_w)
    # The shifted coordinates above are floats; slicing needs integers.
    x0, y0, x1, y1 = (int(coord) for coord in bbox)

    img = img[y0:y1, x0:x1]
    img = cv2.resize(img, (256, 256))
    cv2.imwrite(save_img, img)