# utils.py

import argparse
import json
import logging
import os
import pickle
import random
from collections import defaultdict

import cv2
import dlib
import numpy as np
import torch
from scipy.io import loadmat
from torch import nn

from configs.default import get_cfg_defaults

def _reset_parameters(model):
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

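# Usage sketch (any nn.Module works; `model` below is a placeholder):
#   model = nn.Linear(16, 8)
#   _reset_parameters(model)  # weight matrices get Xavier-uniform init;
#                             # 1-D params such as biases are left untouched
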
def get_video_style(video_name, style_type):
    person_id, direction, emotion, level, *_ = video_name.split("_")
    if style_type == "id_dir_emo_level":
        style = "_".join([person_id, direction, emotion, level])
    elif style_type == "emotion":
        style = emotion
    elif style_type == "id":
        style = person_id
    else:
        raise ValueError("Unknown style type")
    return style

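# Example (hypothetical "<person>_<direction>_<emotion>_<level>_<clip>" naming;
# the exact scheme is an assumption, not defined in this file):
#   get_video_style("M030_front_happy_3_001", "emotion")           -> "happy"
#   get_video_style("M030_front_happy_3_001", "id")                -> "M030"
#   get_video_style("M030_front_happy_3_001", "id_dir_emo_level")  -> "M030_front_happy_3"
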
def get_style_video_lists(video_list, style_type):
    style2video_list = defaultdict(list)
    for video in video_list:
        style = get_video_style(video, style_type)
        style2video_list[style].append(video)
    return style2video_list

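# Example (same hypothetical names as above): grouping videos by emotion returns a
# defaultdict mapping each style key to its video names:
#   get_style_video_lists(
#       ["M030_front_happy_3_001", "W011_front_happy_2_004", "M030_front_sad_3_002"],
#       "emotion",
#   )
#   -> {"happy": ["M030_front_happy_3_001", "W011_front_happy_2_004"],
#       "sad": ["M030_front_sad_3_002"]}
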
def get_face3d_clip(
    video_name, video_root_dir, num_frames, start_idx, dtype=torch.float32
):
    """Load a fixed-length clip of 3DMM expression coefficients for one video.

    Args:
        video_name (str): file name of the 3DMM sequence (.mat or .txt).
        video_root_dir (str): directory containing the 3DMM files.
        num_frames (int): number of frames in the returned clip.
        start_idx ("random" | "middle" | int): where the clip starts.
        dtype (torch.dtype, optional): dtype of the returned tensor.
            Defaults to torch.float32.

    Raises:
        ValueError: if the file extension or start_idx is invalid.

    Returns:
        torch.Tensor: (num_frames, C) expression coefficients.
    """
    video_path = os.path.join(video_root_dir, video_name)
    if video_path[-3:] == "mat":
        face3d_all = loadmat(video_path)["coeff"]
        face3d_exp = face3d_all[:, 80:144]  # expression 3DMM range
    elif video_path[-3:] == "txt":
        face3d_exp = np.loadtxt(video_path)
    else:
        raise ValueError("Invalid 3DMM file extension")

    length = face3d_exp.shape[0]
    clip_num_frames = num_frames
    if start_idx == "random":
        clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
    elif start_idx == "middle":
        clip_start_idx = (length - clip_num_frames + 1) // 2
    elif isinstance(start_idx, int):
        clip_start_idx = start_idx
    else:
        raise ValueError(f"Invalid start_idx {start_idx}")

    face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
    face3d_clip = torch.tensor(face3d_clip, dtype=dtype)
    return face3d_clip

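# Shape sketch (path is a placeholder): for a .mat whose "coeff" array has at least
# 144 columns, columns 80:144 are taken, so the clip is a (num_frames, 64) tensor:
#   clip = get_face3d_clip("M030_front_happy_3_001.mat", "data/3dmm", 64, "random")
#   clip.shape  -> torch.Size([64, 64])
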
def get_video_style_clip(
    video_name,
    video_root_dir,
    style_max_len,
    start_idx="random",
    dtype=torch.float32,
    return_start_idx=False,
):
    video_path = os.path.join(video_root_dir, video_name)
    if video_path[-3:] == "mat":
        face3d_all = loadmat(video_path)["coeff"]
        face3d_exp = face3d_all[:, 80:144]  # expression 3DMM range
    elif video_path[-3:] == "txt":
        face3d_exp = np.loadtxt(video_path)
    else:
        raise ValueError("Invalid 3DMM file extension")

    face3d_exp = torch.tensor(face3d_exp, dtype=dtype)

    length = face3d_exp.shape[0]
    if length >= style_max_len:
        clip_num_frames = style_max_len
        if start_idx == "random":
            clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
        elif start_idx == "middle":
            clip_start_idx = (length - clip_num_frames + 1) // 2
        elif isinstance(start_idx, int):
            clip_start_idx = start_idx
        else:
            raise ValueError(f"Invalid start_idx {start_idx}")

        face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
        pad_mask = torch.tensor([False] * style_max_len)
    else:
        clip_start_idx = None
        padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
        face3d_clip = torch.cat((face3d_exp, padding), dim=0)
        pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))

    if return_start_idx:
        return face3d_clip, pad_mask, clip_start_idx
    else:
        return face3d_clip, pad_mask

def get_video_style_clip_from_np(
    face3d_exp,
    style_max_len,
    start_idx="random",
    dtype=torch.float32,
    return_start_idx=False,
):
    face3d_exp = torch.tensor(face3d_exp, dtype=dtype)

    length = face3d_exp.shape[0]
    if length >= style_max_len:
        clip_num_frames = style_max_len
        if start_idx == "random":
            clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
        elif start_idx == "middle":
            clip_start_idx = (length - clip_num_frames + 1) // 2
        elif isinstance(start_idx, int):
            clip_start_idx = start_idx
        else:
            raise ValueError(f"Invalid start_idx {start_idx}")

        face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
        pad_mask = torch.tensor([False] * style_max_len)
    else:
        clip_start_idx = None
        padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
        face3d_clip = torch.cat((face3d_exp, padding), dim=0)
        pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))

    if return_start_idx:
        return face3d_clip, pad_mask, clip_start_idx
    else:
        return face3d_clip, pad_mask

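# Minimal padding sketch (synthetic data): a 100-frame sequence padded to style_max_len.
# The same logic applies to the file-based get_video_style_clip above.
#   exp = np.random.randn(100, 64).astype(np.float32)
#   clip, mask = get_video_style_clip_from_np(exp, style_max_len=256)
#   clip.shape        -> torch.Size([256, 64])   # zeros appended after frame 100
#   mask[:100].any()  -> tensor(False)           # real frames
#   mask[100:].all()  -> tensor(True)            # padded frames
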
def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
    """Slice per-frame wav2vec windows from an audio feature sequence.

    Args:
        audio_feat (np.ndarray): (N, 1024) wav2vec features.
        start_idx (int): index of the first output frame; window centres are
            taken at 2 * frame index.
        num_frames (int): number of output frames.
        win_size (int): half window size; each window covers 2 * win_size + 1
            audio features.

    Returns:
        np.ndarray: (num_frames, 2 * win_size + 1, 1024), zero-padded at the
            sequence boundaries.
    """
    center_idx_list = [2 * idx for idx in range(start_idx, start_idx + num_frames)]
    audio_window_list = []
    padding = np.zeros(audio_feat.shape[1], dtype=np.float32)
    for center_idx in center_idx_list:
        cur_audio_window = []
        for i in range(center_idx - win_size, center_idx + win_size + 1):
            if i < 0:
                cur_audio_window.append(padding)
            elif i >= len(audio_feat):
                cur_audio_window.append(padding)
            else:
                cur_audio_window.append(audio_feat[i])
        cur_audio_win_array = np.stack(cur_audio_window, axis=0)
        audio_window_list.append(cur_audio_win_array)
    audio_window_array = np.stack(audio_window_list, axis=0)
    return audio_window_array

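# Shape sketch (synthetic features): each output frame gets a window of
# 2 * win_size + 1 audio features centred at twice its frame index:
#   feat = np.zeros((500, 1024), dtype=np.float32)
#   win = get_wav2vec_audio_window(feat, start_idx=0, num_frames=10, win_size=2)
#   win.shape  -> (10, 5, 1024)
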
def setup_config():
    parser = argparse.ArgumentParser(description="voice2pose main program")
    parser.add_argument(
        "--config_file", default="", metavar="FILE", help="path to config file"
    )
    parser.add_argument(
        "--resume_from", type=str, default=None, help="the checkpoint to resume from"
    )
    parser.add_argument(
        "--test_only", action="store_true", help="perform testing and evaluation only"
    )
    parser.add_argument(
        "--demo_input", type=str, default=None, help="path to input for demo"
    )
    parser.add_argument(
        "--checkpoint", type=str, default=None, help="the checkpoint to test with"
    )
    parser.add_argument("--tag", type=str, default="", help="tag for the experiment")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        help="local rank for DistributedDataParallel",
    )
    parser.add_argument(
        "--master_port",
        type=str,
        default="12345",
    )
    parser.add_argument(
        "--max_audio_len",
        type=int,
        default=450,
        help="max_audio_len for inference",
    )
    parser.add_argument(
        "--ddim_num_step",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--inference_seed",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--inference_sample_method",
        type=str,
        default="ddim",
    )
    args = parser.parse_args()

    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return args, cfg

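# Typical invocation (script name and paths are placeholders): any trailing
# KEY VALUE pairs are collected in `opts` and merged via cfg.merge_from_list:
#   python main.py --config_file configs/default.yaml --test_only \
#       --checkpoint checkpoints/model.pth
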
def setup_logger(base_path, exp_name):
    rootLogger = logging.getLogger()
    rootLogger.setLevel(logging.INFO)

    logFormatter = logging.Formatter("%(asctime)s [%(levelname)-0.5s] %(message)s")

    log_path = "{0}/{1}.log".format(base_path, exp_name)
    fileHandler = logging.FileHandler(log_path)
    fileHandler.setFormatter(logFormatter)
    rootLogger.addHandler(fileHandler)

    consoleHandler = logging.StreamHandler()
    consoleHandler.setFormatter(logFormatter)
    rootLogger.addHandler(consoleHandler)
    rootLogger.handlers[0].setLevel(logging.INFO)

    logging.info("log path: %s" % log_path)

def cosine_loss(a, v, y, logloss=nn.BCELoss()):
    d = nn.functional.cosine_similarity(a, v)
    loss = logloss(d.unsqueeze(1), y)
    return loss

def get_pose_params(mat_path):
    """Get pose parameters from a 3DMM .mat file.

    Args:
        mat_path (str): path of the .mat file.

    Returns:
        pose_params (numpy.ndarray): shape (L_video, 9); 3 rotation angles,
            3 translations, and 3 crop parameters per frame.
    """
    mat_dict = loadmat(mat_path)

    np_3dmm = mat_dict["coeff"]
    angles = np_3dmm[:, 224:227]
    translations = np_3dmm[:, 254:257]

    np_trans_params = mat_dict["transform_params"]
    crop = np_trans_params[:, -3:]

    pose_params = np.concatenate((angles, translations, crop), axis=1)
    return pose_params

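# Sketch (path is a placeholder): for a .mat with an (L, 257) "coeff" array and a
# "transform_params" array whose last three columns hold the crop parameters:
#   get_pose_params("data/3dmm/M030_front_happy_3_001.mat").shape  -> (L, 9)
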
def sinusoidal_embedding(timesteps, dim):
    """Sinusoidal timestep embedding.

    Args:
        timesteps (torch.Tensor): (B,) timestep values.
        dim (int): embedding dimension C_embed.

    Returns:
        torch.Tensor: (B, C_embed) embedding.
    """
    half = dim // 2
    timesteps = timesteps.float()

    # compute sinusoidal embedding
    sinusoid = torch.outer(
        timesteps, torch.pow(10000, -torch.arange(half).to(timesteps).div(half))
    )
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    if dim % 2 != 0:
        x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
    return x

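# Quick shape check (synthetic timesteps):
#   t = torch.arange(4)
#   sinusoidal_embedding(t, 128).shape  -> torch.Size([4, 128])
# The first dim // 2 channels are cosines and the rest sines of
# t * 10000 ** (-k / (dim // 2)); an odd `dim` gets one extra zero channel.
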
def reshape_audio_feat(style_audio_all_raw, stride):
    """Stack every `stride` consecutive audio frames into one feature vector.

    Args:
        style_audio_all_raw (np.ndarray): (stride * L, C); trailing frames that
            do not fill a full stride are dropped.
        stride (int): number of consecutive frames to merge.

    Returns:
        np.ndarray: (L, C * stride)
    """
    style_audio_all_raw = style_audio_all_raw[
        : style_audio_all_raw.shape[0] // stride * stride
    ]
    style_audio_all_raw = style_audio_all_raw.reshape(
        style_audio_all_raw.shape[0] // stride, stride, style_audio_all_raw.shape[1]
    )
    style_audio_all = style_audio_all_raw.reshape(style_audio_all_raw.shape[0], -1)
    return style_audio_all

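# Example (synthetic array): with stride=2, consecutive pairs of rows are merged
# along the channel axis and the leftover row is dropped:
#   feat = np.arange(11 * 4).reshape(11, 4)
#   reshape_audio_feat(feat, stride=2).shape  -> (5, 8)
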
def get_derangement_tuple(n):
    while True:
        v = [i for i in range(n)]
        for j in range(n - 1, -1, -1):
            p = random.randint(0, j)
            if v[p] == j:
                break
            else:
                v[j], v[p] = v[p], v[j]
        else:
            if v[0] != 0:
                return tuple(v)

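# Example: a derangement is a random permutation with no fixed points, so for n = 4
# a possible result is (2, 3, 1, 0) -- get_derangement_tuple(4)[i] != i for every i.
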
def compute_aspect_preserved_bbox(bbox, increase_area, h, w):
    left, top, right, bot = bbox
    width = right - left
    height = bot - top

    width_increase = max(
        increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)
    )
    height_increase = max(
        increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)
    )

    left_t = int(left - width_increase * width)
    top_t = int(top - height_increase * height)
    right_t = int(right + width_increase * width)
    bot_t = int(bot + height_increase * height)

    left_oob = -min(0, left_t)
    right_oob = right - min(right_t, w)
    top_oob = -min(0, top_t)
    bot_oob = bot - min(bot_t, h)

    if max(left_oob, right_oob, top_oob, bot_oob) > 0:
        max_w = max(left_oob, right_oob)
        max_h = max(top_oob, bot_oob)
        if max_w > max_h:
            return left_t + max_w, top_t + max_w, right_t - max_w, bot_t - max_w
        else:
            return left_t + max_h, top_t + max_h, right_t - max_h, bot_t - max_h
    else:
        return (left_t, top_t, right_t, bot_t)

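# Worked example: a square box expanded by increase_area=0.5, well inside a
# 1000 x 1000 image, simply grows by half its size in every direction:
#   compute_aspect_preserved_bbox((100, 100, 200, 200), 0.5, 1000, 1000)
#   -> (50, 50, 250, 250)
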
def crop_src_image(src_img, save_img, increase_ratio, detector=None):
    if detector is None:
        detector = dlib.get_frontal_face_detector()

    img = cv2.imread(src_img)
    faces = detector(img, 0)
    h, width, _ = img.shape
    if len(faces) > 0:
        bbox = [faces[0].left(), faces[0].top(), faces[0].right(), faces[0].bottom()]
        l = bbox[3] - bbox[1]
        bbox[1] = bbox[1] - l * 0.1
        bbox[3] = bbox[3] - l * 0.1
        bbox[1] = max(0, bbox[1])
        bbox[3] = min(h, bbox[3])
        bbox = compute_aspect_preserved_bbox(
            tuple(bbox), increase_ratio, img.shape[0], img.shape[1]
        )
        img = img[bbox[1] : bbox[3], bbox[0] : bbox[2]]
        img = cv2.resize(img, (256, 256))
        cv2.imwrite(save_img, img)
    else:
        raise ValueError("No face detected in the input image")
        # img = cv2.resize(img, (256, 256))
        # cv2.imwrite(save_img, img)
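
# Usage sketch (paths and ratio are placeholders): detect the first dlib face,
# shift the box up by 10% of its height, expand it with
# compute_aspect_preserved_bbox, and write a 256 x 256 crop:
#   crop_src_image("portrait.jpg", "portrait_crop.jpg", increase_ratio=0.4)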