inference_for_demo_video.py

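"""Demo inference script.

Pipeline, summarized from the code below: the input audio is resampled to
16 kHz mono with ffmpeg, encoded with wav2vec 2.0, and fed to the diffusion
network together with a style clip to sample per-frame facial expression
parameters; these are concatenated with pose parameters and rendered into a
video of the source portrait with the pretrained renderer.
"""
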
import argparse
import json
import os
import shutil
import subprocess

import numpy as np
import torch
import torchaudio
from scipy.io import loadmat
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model

from configs.default import get_cfg_defaults
from core.networks.diffusion_net import DiffusionNet
from core.networks.diffusion_util import NoisePredictor, VarianceSchedule
from core.utils import (
    crop_src_image,
    get_pose_params,
    get_video_style_clip,
    get_wav2vec_audio_window,
)
from generators.utils import get_netG, render_video


@torch.no_grad()
def get_diff_net(cfg, device):
    diff_net = DiffusionNet(
        cfg=cfg,
        net=NoisePredictor(cfg),
        var_sched=VarianceSchedule(
            num_steps=cfg.DIFFUSION.SCHEDULE.NUM_STEPS,
            beta_1=cfg.DIFFUSION.SCHEDULE.BETA_1,
            beta_T=cfg.DIFFUSION.SCHEDULE.BETA_T,
            mode=cfg.DIFFUSION.SCHEDULE.MODE,
        ),
    )
    checkpoint = torch.load(cfg.INFERENCE.CHECKPOINT, map_location=device)
    model_state_dict = checkpoint["model_state_dict"]
    # keep only the diffusion-net weights and strip the "diff_net." prefix
    diff_net_dict = {
        k[9:]: v for k, v in model_state_dict.items() if k[:9] == "diff_net."
    }
    diff_net.load_state_dict(diff_net_dict, strict=True)
    diff_net.eval()
    return diff_net


@torch.no_grad()
def get_audio_feat(wav_path, output_name, wav2vec_model):
    # Unused stub: the wav2vec feature extraction is performed inline in the
    # __main__ block below.
    pass
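

# The helper below is a minimal sketch of how get_audio_feat could be filled in,
# mirroring the wav2vec feature extraction done inline in __main__ further down.
# It is not part of the original script and is never called; the function name,
# the extra wav2vec_processor/device parameters, and the tmp/<output_name> path
# layout are assumptions for illustration only.
@torch.no_grad()
def get_audio_feat_sketch(wav_path, output_name, wav2vec_processor, wav2vec_model, device):
    # load the 16 kHz mono waveform and run it through wav2vec 2.0
    speech_array, sampling_rate = torchaudio.load(wav_path)
    audio_data = speech_array.squeeze().numpy()
    inputs = wav2vec_processor(
        audio_data, sampling_rate=16_000, return_tensors="pt", padding=True
    )
    audio_embedding = wav2vec_model(
        inputs.input_values.to(device), return_dict=False
    )[0]
    # save the (T, hidden_dim) hidden states next to the other temporary files
    audio_feat_path = os.path.join("tmp", output_name, f"{output_name}_wav2vec.npy")
    os.makedirs(os.path.dirname(audio_feat_path), exist_ok=True)
    np.save(audio_feat_path, audio_embedding[0].cpu().numpy())
    return audio_feat_path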


@torch.no_grad()
def inference_one_video(
    cfg,
    audio_path,
    style_clip_path,
    pose_path,
    output_path,
    diff_net,
    device,
    max_audio_len=None,
    sample_method="ddim",
    ddim_num_step=10,
):
    # wav2vec features arrive at ~50 per second of audio; video frames are
    # generated at 25 fps, hence two audio feature frames per video frame
    audio_raw = np.load(audio_path)
    if max_audio_len is not None:
        audio_raw = audio_raw[: max_audio_len * 50]
    gen_num_frames = len(audio_raw) // 2

    audio_win_array = get_wav2vec_audio_window(
        audio_raw,
        start_idx=0,
        num_frames=gen_num_frames,
        win_size=cfg.WIN_SIZE,
    )
    audio_win = torch.tensor(audio_win_array).to(device)
    audio = audio_win.unsqueeze(0)

    # the second parameter is "" because of bad interface design...
    style_clip_raw, style_pad_mask_raw = get_video_style_clip(
        style_clip_path, "", style_max_len=256, start_idx=0
    )
    style_clip = style_clip_raw.unsqueeze(0).to(device)
    style_pad_mask = (
        style_pad_mask_raw.unsqueeze(0).to(device)
        if style_pad_mask_raw is not None
        else None
    )

    gen_exp_stack = diff_net.sample(
        audio,
        style_clip,
        style_pad_mask,
        output_dim=cfg.DATASET.FACE3D_DIM,
        use_cf_guidance=cfg.CF_GUIDANCE.INFERENCE,
        cfg_scale=cfg.CF_GUIDANCE.SCALE,
        sample_method=sample_method,
        ddim_num_step=ddim_num_step,
    )
    gen_exp = gen_exp_stack[0].cpu().numpy()

    # pose parameters, shape (L, 9)
    pose = get_pose_params(pose_path)
    if len(pose) >= len(gen_exp):
        selected_pose = pose[: len(gen_exp)]
    else:
        # pose is a NumPy array, so pad by repeating the last row with np.tile
        selected_pose = np.tile(pose[-1:], (len(gen_exp), 1))
        selected_pose[: len(pose)] = pose

    gen_exp_pose = np.concatenate((gen_exp, selected_pose), axis=1)
    np.save(output_path, gen_exp_pose)
    return output_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="inference for demo")
    parser.add_argument("--wav_path", type=str, default="", help="path to the input audio file")
    parser.add_argument("--image_path", type=str, default="", help="path to the source portrait image")
    parser.add_argument("--disable_img_crop", dest="img_crop", action="store_false")
    parser.set_defaults(img_crop=True)
    parser.add_argument(
        "--style_clip_path", type=str, default="", help="path to the style clip .mat file"
    )
    parser.add_argument("--pose_path", type=str, default="", help="path to the pose parameter file")
    parser.add_argument(
        "--max_gen_len",
        type=int,
        default=1000,
        help="maximum length (in seconds) of the generated video",
    )
    parser.add_argument(
        "--cfg_scale",
        type=float,
        default=1.0,
        help="scale of the classifier-free guidance",
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default="test",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
    )
    args = parser.parse_args()

    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available; pass --device=cpu to run on CPU.")
        exit(1)
    device = torch.device(args.device)

    cfg = get_cfg_defaults()
    cfg.CF_GUIDANCE.SCALE = args.cfg_scale
    cfg.freeze()

    tmp_dir = f"tmp/{args.output_name}"
    os.makedirs(tmp_dir, exist_ok=True)

    # convert the input audio to 16 kHz mono PCM for wav2vec
    wav_16k_path = os.path.join(tmp_dir, f"{args.output_name}_16K.wav")
    command = [
        "ffmpeg", "-y", "-i", args.wav_path,
        "-async", "1", "-ac", "1", "-vn",
        "-acodec", "pcm_s16le", "-ar", "16000",
        wav_16k_path,
    ]
    subprocess.run(command, check=True)

    # extract wav2vec 2.0 features from the audio
    wav2vec_processor = Wav2Vec2Processor.from_pretrained(
        "jonatasgrosman/wav2vec2-large-xlsr-53-english"
    )
    wav2vec_model = (
        Wav2Vec2Model.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
        .eval()
        .to(device)
    )
    speech_array, sampling_rate = torchaudio.load(wav_16k_path)
    audio_data = speech_array.squeeze().numpy()
    inputs = wav2vec_processor(
        audio_data, sampling_rate=16_000, return_tensors="pt", padding=True
    )
    with torch.no_grad():
        audio_embedding = wav2vec_model(
            inputs.input_values.to(device), return_dict=False
        )[0]
    audio_feat_path = os.path.join(tmp_dir, f"{args.output_name}_wav2vec.npy")
    # save the (T, hidden_dim) hidden states as the audio feature
    np.save(audio_feat_path, audio_embedding[0].cpu().numpy())

    # get src image
    src_img_path = os.path.join(tmp_dir, "src_img.png")
    if args.img_crop:
        crop_src_image(args.image_path, src_img_path, 0.4)
    else:
        shutil.copy(args.image_path, src_img_path)

    with torch.no_grad():
        # get diff model and load checkpoint
        diff_net = get_diff_net(cfg, device).to(device)
        # generate face motion
        face_motion_path = os.path.join(tmp_dir, f"{args.output_name}_facemotion.npy")
        inference_one_video(
            cfg,
            audio_feat_path,
            args.style_clip_path,
            args.pose_path,
            face_motion_path,
            diff_net,
            device,
            max_audio_len=args.max_gen_len,
        )
        # get renderer
        renderer = get_netG("checkpoints/renderer.pt", device)
        # render video
        os.makedirs("output_video", exist_ok=True)
        output_video_path = f"output_video/{args.output_name}.mp4"
        render_video(
            renderer,
            src_img_path,
            face_motion_path,
            wav_16k_path,
            output_video_path,
            device,
            fps=25,
            no_move=False,
        )

        # add watermark
        # uncomment this block to overlay a watermark on the output video;
        # leave it commented out to keep the raw output (e.g. for evaluation).
        # no_watermark_video_path = f"{output_video_path}-no_watermark.mp4"
        # shutil.move(output_video_path, no_watermark_video_path)
        # os.system(
        #     f'ffmpeg -y -i {no_watermark_video_path} -vf "movie=media/watermark.png,scale=120:36[watermask]; [in][watermask] overlay=140:220 [out]" {output_video_path}'
        # )
        # os.remove(no_watermark_video_path)
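

# Example invocation (illustrative only; the paths below are placeholders, and
# the diffusion checkpoint referenced by cfg.INFERENCE.CHECKPOINT as well as
# checkpoints/renderer.pt must already be in place):
#
#   python inference_for_demo_video.py \
#       --wav_path path/to/speech.wav \
#       --image_path path/to/portrait.png \
#       --style_clip_path path/to/style_clip.mat \
#       --pose_path path/to/pose.mat \
#       --max_gen_len 30 \
#       --cfg_scale 1.0 \
#       --output_name demo \
#       --device cuda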