Browse Source

add files

永徽 1 year ago
parent
commit
50116789a2
100 changed files with 2462 additions and 1 deletions
  1. 100 1
      README.md
  2. 0 0
      checkpoints/.gitkeep
  3. 91 0
      configs/default.py
  4. 14 0
      core/networks/__init__.py
  5. 340 0
      core/networks/diffusion_net.py
  6. 131 0
      core/networks/diffusion_util.py
  7. 240 0
      core/networks/disentangle_decoder.py
  8. 156 0
      core/networks/dynamic_conv.py
  9. 178 0
      core/networks/dynamic_fc_decoder.py
  10. 50 0
      core/networks/dynamic_linear.py
  11. 309 0
      core/networks/generator.py
  12. 51 0
      core/networks/mish.py
  13. 53 0
      core/networks/self_attention_pooling.py
  14. 293 0
      core/networks/transformer.py
  15. 456 0
      core/utils.py
  16. BIN
      data/audio/German1.wav
  17. BIN
      data/audio/German2.wav
  18. BIN
      data/audio/German3.wav
  19. BIN
      data/audio/German4.wav
  20. BIN
      data/audio/acknowledgement_chinese.m4a
  21. BIN
      data/audio/acknowledgement_english.m4a
  22. BIN
      data/audio/chinese1_haierlizhi.wav
  23. BIN
      data/audio/chinese2_guanyu.wav
  24. BIN
      data/audio/french1.wav
  25. BIN
      data/audio/french2.wav
  26. BIN
      data/audio/french3.wav
  27. BIN
      data/audio/italian1.wav
  28. BIN
      data/audio/italian2.wav
  29. BIN
      data/audio/italian3.wav
  30. BIN
      data/audio/japan1.wav
  31. BIN
      data/audio/japan2.wav
  32. BIN
      data/audio/japan3.wav
  33. BIN
      data/audio/korean1.wav
  34. BIN
      data/audio/korean2.wav
  35. BIN
      data/audio/korean3.wav
  36. BIN
      data/audio/noisy_audio_cafeter_snr_0.wav
  37. BIN
      data/audio/noisy_audio_meeting_snr_0.wav
  38. BIN
      data/audio/noisy_audio_meeting_snr_10.wav
  39. BIN
      data/audio/noisy_audio_meeting_snr_20.wav
  40. BIN
      data/audio/noisy_audio_narrative.wav
  41. BIN
      data/audio/noisy_audio_office_snr_0.wav
  42. BIN
      data/audio/out_of_domain_narrative.wav
  43. BIN
      data/audio/spanish1.wav
  44. BIN
      data/audio/spanish2.wav
  45. BIN
      data/audio/spanish3.wav
  46. BIN
      data/pose/RichardShelby_front_neutral_level1_001.mat
  47. BIN
      data/src_img/cropped/chpa5.png
  48. BIN
      data/src_img/cropped/cut_img.png
  49. BIN
      data/src_img/cropped/f30.png
  50. BIN
      data/src_img/cropped/menglu2.png
  51. BIN
      data/src_img/cropped/nscu2.png
  52. BIN
      data/src_img/cropped/zp1.png
  53. BIN
      data/src_img/cropped/zt12.png
  54. BIN
      data/src_img/uncropped/face3.png
  55. BIN
      data/src_img/uncropped/male_face.png
  56. BIN
      data/src_img/uncropped/uncut_src_img.jpg
  57. BIN
      data/style_clip/3DMM/M030_front_angry_level3_001.mat
  58. BIN
      data/style_clip/3DMM/M030_front_contempt_level3_001.mat
  59. BIN
      data/style_clip/3DMM/M030_front_disgusted_level3_001.mat
  60. BIN
      data/style_clip/3DMM/M030_front_fear_level3_001.mat
  61. BIN
      data/style_clip/3DMM/M030_front_happy_level3_001.mat
  62. BIN
      data/style_clip/3DMM/M030_front_neutral_level1_001.mat
  63. BIN
      data/style_clip/3DMM/M030_front_sad_level3_001.mat
  64. BIN
      data/style_clip/3DMM/M030_front_surprised_level3_001.mat
  65. BIN
      data/style_clip/3DMM/W009_front_angry_level3_001.mat
  66. BIN
      data/style_clip/3DMM/W009_front_contempt_level3_001.mat
  67. BIN
      data/style_clip/3DMM/W009_front_disgusted_level3_001.mat
  68. BIN
      data/style_clip/3DMM/W009_front_fear_level3_001.mat
  69. BIN
      data/style_clip/3DMM/W009_front_happy_level3_001.mat
  70. BIN
      data/style_clip/3DMM/W009_front_neutral_level1_001.mat
  71. BIN
      data/style_clip/3DMM/W009_front_sad_level3_001.mat
  72. BIN
      data/style_clip/3DMM/W009_front_surprised_level3_001.mat
  73. BIN
      data/style_clip/3DMM/W011_front_angry_level3_001.mat
  74. BIN
      data/style_clip/3DMM/W011_front_contempt_level3_001.mat
  75. BIN
      data/style_clip/3DMM/W011_front_disgusted_level3_001.mat
  76. BIN
      data/style_clip/3DMM/W011_front_fear_level3_001.mat
  77. BIN
      data/style_clip/3DMM/W011_front_happy_level3_001.mat
  78. BIN
      data/style_clip/3DMM/W011_front_neutral_level1_001.mat
  79. BIN
      data/style_clip/3DMM/W011_front_sad_level3_001.mat
  80. BIN
      data/style_clip/3DMM/W011_front_surprised_level3_001.mat
  81. BIN
      data/style_clip/video/M030_front_angry_level3_001.mp4
  82. BIN
      data/style_clip/video/M030_front_contempt_level3_001.mp4
  83. BIN
      data/style_clip/video/M030_front_disgusted_level3_001.mp4
  84. BIN
      data/style_clip/video/M030_front_fear_level3_001.mp4
  85. BIN
      data/style_clip/video/M030_front_happy_level3_001.mp4
  86. BIN
      data/style_clip/video/M030_front_neutral_level1_001.mp4
  87. BIN
      data/style_clip/video/M030_front_sad_level3_001.mp4
  88. BIN
      data/style_clip/video/M030_front_surprised_level3_001.mp4
  89. BIN
      data/style_clip/video/W009_front_angry_level3_001.mp4
  90. BIN
      data/style_clip/video/W009_front_contempt_level3_001.mp4
  91. BIN
      data/style_clip/video/W009_front_disgusted_level3_001.mp4
  92. BIN
      data/style_clip/video/W009_front_fear_level3_001.mp4
  93. BIN
      data/style_clip/video/W009_front_happy_level3_001.mp4
  94. BIN
      data/style_clip/video/W009_front_neutral_level1_001.mp4
  95. BIN
      data/style_clip/video/W009_front_sad_level3_001.mp4
  96. BIN
      data/style_clip/video/W009_front_surprised_level3_001.mp4
  97. BIN
      data/style_clip/video/W011_front_angry_level3_001.mp4
  98. BIN
      data/style_clip/video/W011_front_contempt_level3_001.mp4
  99. BIN
      data/style_clip/video/W011_front_disgusted_level3_001.mp4
  100. BIN
      data/style_clip/video/W011_front_fear_level3_001.mp4

+ 100 - 1
README.md

@@ -1 +1,100 @@
-# dreamtalk
+# DreamTalk: When Expressive Talking Head Generation Meets Diffusion Probabilistic Models
+
+<a href='https://dreamtalk-project.github.io/'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2312.09767'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/VF4vlE6ZqWQ)
+
+DreamTalk is a diffusion-based audio-driven expressive talking head generation framework that can produce high-quality talking head videos across diverse speaking styles. DreamTalk exhibits robust performance with a diverse array of inputs, including songs, speech in multiple languages, noisy audio, and out-of-domain portraits.
+
+![figure1](media/teaser.gif "teaser")
+
+## News
+- __[2023.12]__ Release inference code and pretrained checkpoint.
+
+## Installation
+
+```
+conda create -n dreamtalk python=3.7.0
+conda activate dreamtalk
+pip install -r requirements.txt
+conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
+conda update ffmpeg
+
+pip install urllib3==1.26.6
+pip install transformers==4.28.1
+pip install dlib
+```
+
+## Download Checkpoints
+Download the checkpoint of the denoising network: 
+* [ModelScope](https://modelscope.cn/models/damo/dreamtalk/file/view/master/checkpoints%2Fdenoising_network.pth?status=2)
+
+
+Download the checkpoint of the renderer: 
+* [ModelScope](https://modelscope.cn/models/damo/dreamtalk/file/view/master/checkpoints%2Frenderer.pt?status=2)
+
+Put the downloaded checkpoints into `checkpoints` folder.
+
+
+## Inference
+Run the script:
+
+```
+python inference_for_demo_video.py \
+--wav_path data/audio/acknowledgement_english.m4a \
+--style_clip_path data/style_clip/3DMM/M030_front_neutral_level1_001.mat \
+--pose_path data/pose/RichardShelby_front_neutral_level1_001.mat \
+--image_path data/src_img/uncropped/male_face.png \
+--cfg_scale 1.0 \
+--max_gen_len 30 \
+--output_name acknowledgement_english@M030_front_neutral_level1_001@male_face
+```
+
+`wav_path` specifies the input audio. The input audio file extensions such as wav, mp3, m4a, and mp4 (video with sound) should all be compatible.
+
+`style_clip_path` specifies the reference speaking style and `pose_path` specifies head pose. They are 3DMM paramenter sequences extracted from reference videos. You can follow [PIRenderer](https://github.com/RenYurui/PIRender) to extract 3DMM parameters from your own videos. Note that the video frame rate should be 25 FPS. Besides, videos used for head pose reference should be first cropped to $256\times256$ using scripts in [FOMM video preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing).
+
+`image_path` specifies the input portrait. Its resolution should be larger than $256\times256$. Frontal portraits, with the face directly facing forward and not tilted to one side, usually achieve satisfactory results. The input portrait will be cropped to $256\times256$. If your portrait is already cropped to $256\times256$ and you want to disable cropping, use option `--disable_img_crop` like this:
+
+```
+python inference_for_demo_video.py \
+--wav_path data/audio/acknowledgement_chinese.m4a \
+--style_clip_path data/style_clip/3DMM/M030_front_surprised_level3_001.mat \
+--pose_path data/pose/RichardShelby_front_neutral_level1_001.mat \
+--image_path data/src_img/cropped/zp1.png \
+--disable_img_crop \
+--cfg_scale 1.0 \
+--max_gen_len 30 \
+--output_name acknowledgement_chinese@M030_front_surprised_level3_001@zp1
+```
+
+`cfg_scale` controls the scale of classifer-free guidance. It can adjust the intensity of speaking styles.
+
+`max_gen_len` is the maximum video generation duration, measured in seconds. If the input audio exceeds this length, it will be truncated.
+
+The generated video will be named `$(output_name).mp4` and put in the output_video folder. Intermediate results, including the cropped portrait, will be in the `tmp/$(output_name)` folder.
+
+Sample inputs are presented in `data` folder. Due to copyright issues, we are unable to include the songs we have used in this folder.
+
+
+## Acknowledgements
+
+We extend our heartfelt thanks for the invaluable contributions made by preceding works to the development of DreamTalk. This includes, but is not limited to:
+[PIRenderer](https://github.com/RenYurui/PIRender)
+,[AVCT](https://github.com/FuxiVirtualHuman/AAAI22-one-shot-talking-face)
+,[StyleTalk](https://github.com/FuxiVirtualHuman/styletalk)
+,[Deep3DFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch)
+,[Wav2vec2.0](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-english)
+,[diffusion-point-cloud](https://github.com/luost26/diffusion-point-cloud)
+,[FOMM video preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing). We are dedicated to advancing upon these foundational works with the utmost respect for their original contributions.
+
+## Citation
+If you find this codebase useful for your research, please use the following entry.
+```BibTeX
+@article{ma2023dreamtalk,
+  title={DreamTalk: When Expressive Talking Head Generation Meets Diffusion Probabilistic Models},
+  author={Ma, Yifeng and Zhang, Shiwei and Wang, Jiayu and Wang, Xiang and Zhang, Yingya and Deng, Zhidong},
+  journal={arXiv preprint arXiv:2312.09767},
+  year={2023}
+}
+```
+
+

+ 0 - 0
checkpoints/.gitkeep


+ 91 - 0
configs/default.py

@@ -0,0 +1,91 @@
+from yacs.config import CfgNode as CN
+
+
+_C = CN()
+_C.TAG = "style_id_emotion"
+_C.DECODER_TYPE = "DisentangleDecoder"
+_C.CONTENT_ENCODER_TYPE = "ContentW2VEncoder"
+_C.STYLE_ENCODER_TYPE = "StyleEncoder"
+
+_C.DIFFNET_TYPE = "DiffusionNet"
+
+_C.WIN_SIZE = 5
+_C.D_MODEL = 256
+
+_C.DATASET = CN()
+_C.DATASET.FACE3D_DIM = 64
+_C.DATASET.NUM_FRAMES = 64
+_C.DATASET.STYLE_MAX_LEN = 256
+
+_C.TRAIN = CN()
+_C.TRAIN.FACE3D_LATENT = CN()
+_C.TRAIN.FACE3D_LATENT.TYPE = "face3d"
+
+_C.DIFFUSION = CN()
+_C.DIFFUSION.PREDICT_WHAT = "x0"  # noise | x0
+_C.DIFFUSION.SCHEDULE = CN()
+_C.DIFFUSION.SCHEDULE.NUM_STEPS = 1000
+_C.DIFFUSION.SCHEDULE.BETA_1 = 1e-4
+_C.DIFFUSION.SCHEDULE.BETA_T = 0.02
+_C.DIFFUSION.SCHEDULE.MODE = "linear"
+
+_C.CONTENT_ENCODER = CN()
+_C.CONTENT_ENCODER.d_model = _C.D_MODEL
+_C.CONTENT_ENCODER.nhead = 8
+_C.CONTENT_ENCODER.num_encoder_layers = 3
+_C.CONTENT_ENCODER.dim_feedforward = 4 * _C.D_MODEL
+_C.CONTENT_ENCODER.dropout = 0.1
+_C.CONTENT_ENCODER.activation = "relu"
+_C.CONTENT_ENCODER.normalize_before = False
+_C.CONTENT_ENCODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
+
+_C.STYLE_ENCODER = CN()
+_C.STYLE_ENCODER.d_model = _C.D_MODEL
+_C.STYLE_ENCODER.nhead = 8
+_C.STYLE_ENCODER.num_encoder_layers = 3
+_C.STYLE_ENCODER.dim_feedforward = 4 * _C.D_MODEL
+_C.STYLE_ENCODER.dropout = 0.1
+_C.STYLE_ENCODER.activation = "relu"
+_C.STYLE_ENCODER.normalize_before = False
+_C.STYLE_ENCODER.pos_embed_len = _C.DATASET.STYLE_MAX_LEN
+_C.STYLE_ENCODER.aggregate_method = (
+    "self_attention_pooling"  # average | self_attention_pooling
+)
+# _C.STYLE_ENCODER.input_dim = _C.DATASET.FACE3D_DIM
+
+_C.DECODER = CN()
+_C.DECODER.d_model = _C.D_MODEL
+_C.DECODER.nhead = 8
+_C.DECODER.num_decoder_layers = 3
+_C.DECODER.dim_feedforward = 4 * _C.D_MODEL
+_C.DECODER.dropout = 0.1
+_C.DECODER.activation = "relu"
+_C.DECODER.normalize_before = False
+_C.DECODER.return_intermediate_dec = False
+_C.DECODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
+_C.DECODER.network_type = "TransformerDecoder"
+_C.DECODER.dynamic_K = None
+_C.DECODER.dynamic_ratio = None
+# _C.DECODER.output_dim = _C.DATASET.FACE3D_DIM
+# LSFM basis:
+# _C.DECODER.upper_face3d_indices = tuple(list(range(19)) + list(range(46, 51)))
+# _C.DECODER.lower_face3d_indices = tuple(range(19, 46))
+# BFM basis:
+# fmt: off
+_C.DECODER.upper_face3d_indices = [6, 8, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] 
+# fmt: on
+_C.DECODER.lower_face3d_indices = [0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14]
+
+_C.CF_GUIDANCE = CN()
+_C.CF_GUIDANCE.TRAINING = True
+_C.CF_GUIDANCE.INFERENCE = True
+_C.CF_GUIDANCE.NULL_PROB = 0.1
+_C.CF_GUIDANCE.SCALE = 1.0
+
+_C.INFERENCE = CN()
+_C.INFERENCE.CHECKPOINT = "checkpoints/denoising_network.pth"
+
+
+def get_cfg_defaults():
+    """Get a yacs CfgNode object with default values for my_project."""
+    return _C.clone()

+ 14 - 0
core/networks/__init__.py

@@ -0,0 +1,14 @@
+from core.networks.generator import (
+    StyleEncoder,
+    Decoder,
+    ContentW2VEncoder,
+)
+from core.networks.disentangle_decoder import DisentangleDecoder
+
+
+def get_network(name: str):
+    obj = globals().get(name)
+    if obj is None:
+        raise KeyError("Unknown Network: %s" % name)
+    else:
+        return obj

+ 340 - 0
core/networks/diffusion_net.py

@@ -0,0 +1,340 @@
+import math
+import torch
+import torch.nn.functional as F
+from torch.nn import Module
+from core.networks.diffusion_util import VarianceSchedule
+import numpy as np
+
+
+def face3d_raw_to_norm(face3d_raw, exp_min, exp_max):
+    """
+
+    Args:
+        face3d_raw (_type_): (B, L, C_face3d)
+        exp_min (_type_): (C_face3d)
+        exp_max (_type_): (C_face3d)
+
+    Returns:
+        _type_: (B, L, C_face3d) in [-1, 1]
+    """
+    exp_min_expand = exp_min[None, None, :]
+    exp_max_expand = exp_max[None, None, :]
+    face3d_norm_01 = (face3d_raw - exp_min_expand) / (exp_max_expand - exp_min_expand)
+    face3d_norm = face3d_norm_01 * 2 - 1
+    return face3d_norm
+
+
+def face3d_norm_to_raw(face3d_norm, exp_min, exp_max):
+    """
+
+    Args:
+        face3d_norm (_type_): (B, L, C_face3d)
+        exp_min (_type_): (C_face3d)
+        exp_max (_type_): (C_face3d)
+
+    Returns:
+        _type_: (B, L, C_face3d)
+    """
+    exp_min_expand = exp_min[None, None, :]
+    exp_max_expand = exp_max[None, None, :]
+    face3d_norm_01 = (face3d_norm + 1) / 2
+    face3d_raw = face3d_norm_01 * (exp_max_expand - exp_min_expand) + exp_min_expand
+    return face3d_raw
+
+
+class DiffusionNet(Module):
+    def __init__(self, cfg, net, var_sched: VarianceSchedule):
+        super().__init__()
+        self.cfg = cfg
+        self.net = net
+        self.var_sched = var_sched
+        self.face3d_latent_type = self.cfg.TRAIN.FACE3D_LATENT.TYPE
+        self.predict_what = self.cfg.DIFFUSION.PREDICT_WHAT
+
+        if self.cfg.CF_GUIDANCE.TRAINING:
+            null_style_clip = torch.zeros(
+                self.cfg.DATASET.STYLE_MAX_LEN, self.cfg.DATASET.FACE3D_DIM
+            )
+            self.register_buffer("null_style_clip", null_style_clip)
+
+            null_pad_mask = torch.tensor([False] * self.cfg.DATASET.STYLE_MAX_LEN)
+            self.register_buffer("null_pad_mask", null_pad_mask)
+
+    def _face3d_to_latent(self, face3d):
+        latent = None
+        if self.face3d_latent_type == "face3d":
+            latent = face3d
+        elif self.face3d_latent_type == "normalized_face3d":
+            latent = face3d_raw_to_norm(
+                face3d, exp_min=self.exp_min, exp_max=self.exp_max
+            )
+        else:
+            raise ValueError(f"Invalid face3d latent type: {self.face3d_latent_type}")
+        return latent
+
+    def _latent_to_face3d(self, latent):
+        face3d = None
+        if self.face3d_latent_type == "face3d":
+            face3d = latent
+        elif self.face3d_latent_type == "normalized_face3d":
+            latent = torch.clamp(latent, min=-1, max=1)
+            face3d = face3d_norm_to_raw(
+                latent, exp_min=self.exp_min, exp_max=self.exp_max
+            )
+        else:
+            raise ValueError(f"Invalid face3d latent type: {self.face3d_latent_type}")
+        return face3d
+
+    def ddim_sample(
+        self,
+        audio,
+        style_clip,
+        style_pad_mask,
+        output_dim,
+        flexibility=0.0,
+        ret_traj=False,
+        use_cf_guidance=False,
+        cfg_scale=2.0,
+        ddim_num_step=50,
+        ready_style_code=None,
+    ):
+        """
+
+        Args:
+            audio (_type_): (B, L, W) or (B, L, W, C)
+            style_clip (_type_): (B, L_clipmax, C_face3d)
+            style_pad_mask : (B, L_clipmax)
+            pose_dim (_type_): int
+            flexibility (float, optional): _description_. Defaults to 0.0.
+            ret_traj (bool, optional): _description_. Defaults to False.
+
+
+        Returns:
+            _type_: (B, L, C_face)
+        """
+        if self.predict_what != "x0":
+            raise NotImplementedError(self.predict_what)
+
+        if ready_style_code is not None and use_cf_guidance:
+            raise NotImplementedError("not implement cfg for ready style code")
+
+        c = self.var_sched.num_steps // ddim_num_step
+        time_steps = torch.tensor(
+            np.asarray(list(range(0, self.var_sched.num_steps, c))) + 1
+        )
+        assert len(time_steps) == ddim_num_step
+        prev_time_steps = torch.cat((torch.tensor([0]), time_steps[:-1]))
+
+        batch_size, output_len = audio.shape[:2]
+        # batch_size = context.size(0)
+        context = {
+            "audio": audio,
+            "style_clip": style_clip,
+            "style_pad_mask": style_pad_mask,
+            "ready_style_code": ready_style_code,
+        }
+        if use_cf_guidance:
+            uncond_style_clip = self.null_style_clip.unsqueeze(0).repeat(
+                batch_size, 1, 1
+            )
+            uncond_pad_mask = self.null_pad_mask.unsqueeze(0).repeat(batch_size, 1)
+
+            context_double = {
+                "audio": torch.cat([audio] * 2, dim=0),
+                "style_clip": torch.cat([style_clip, uncond_style_clip], dim=0),
+                "style_pad_mask": torch.cat([style_pad_mask, uncond_pad_mask], dim=0),
+                "ready_style_code": None
+                if ready_style_code is None
+                else torch.cat(
+                    [
+                        ready_style_code,
+                        self.net.style_encoder(uncond_style_clip, uncond_pad_mask),
+                    ],
+                    dim=0,
+                ),
+            }
+
+        x_t = torch.randn([batch_size, output_len, output_dim]).to(audio.device)
+
+        for idx in list(range(ddim_num_step))[::-1]:
+            t = time_steps[idx]
+            t_prev = prev_time_steps[idx]
+            ddim_alpha = self.var_sched.alpha_bars[t]
+            ddim_alpha_prev = self.var_sched.alpha_bars[t_prev]
+
+            t_tensor = torch.tensor([t] * batch_size).to(audio.device).float()
+            if use_cf_guidance:
+                x_t_double = torch.cat([x_t] * 2, dim=0)
+                t_tensor_double = torch.cat([t_tensor] * 2, dim=0)
+                cond_output, uncond_output = self.net(
+                    x_t_double, t=t_tensor_double, **context_double
+                ).chunk(2)
+                diff_output = uncond_output + cfg_scale * (cond_output - uncond_output)
+            else:
+                diff_output = self.net(x_t, t=t_tensor, **context)
+
+            pred_x0 = diff_output
+            eps = (x_t - torch.sqrt(ddim_alpha) * pred_x0) / torch.sqrt(1 - ddim_alpha)
+            c1 = torch.sqrt(ddim_alpha_prev)
+            c2 = torch.sqrt(1 - ddim_alpha_prev)
+
+            x_t = c1 * pred_x0 + c2 * eps
+
+        latent_output = x_t
+        face3d_output = self._latent_to_face3d(latent_output)
+        return face3d_output
+
+    def sample(
+        self,
+        audio,
+        style_clip,
+        style_pad_mask,
+        output_dim,
+        flexibility=0.0,
+        ret_traj=False,
+        use_cf_guidance=False,
+        cfg_scale=2.0,
+        sample_method="ddpm",
+        ddim_num_step=50,
+        ready_style_code=None,
+    ):
+        # sample_method = kwargs["sample_method"]
+        if sample_method == "ddpm":
+            if ready_style_code is not None:
+                raise NotImplementedError("ready style code in ddpm")
+            return self.ddpm_sample(
+                audio,
+                style_clip,
+                style_pad_mask,
+                output_dim,
+                flexibility=flexibility,
+                ret_traj=ret_traj,
+                use_cf_guidance=use_cf_guidance,
+                cfg_scale=cfg_scale,
+            )
+        elif sample_method == "ddim":
+            return self.ddim_sample(
+                audio,
+                style_clip,
+                style_pad_mask,
+                output_dim,
+                flexibility=flexibility,
+                ret_traj=ret_traj,
+                use_cf_guidance=use_cf_guidance,
+                cfg_scale=cfg_scale,
+                ddim_num_step=ddim_num_step,
+                ready_style_code=ready_style_code,
+            )
+
+    def ddpm_sample(
+        self,
+        audio,
+        style_clip,
+        style_pad_mask,
+        output_dim,
+        flexibility=0.0,
+        ret_traj=False,
+        use_cf_guidance=False,
+        cfg_scale=2.0,
+    ):
+        """
+
+        Args:
+            audio (_type_): (B, L, W) or (B, L, W, C)
+            style_clip (_type_): (B, L_clipmax, C_face3d)
+            style_pad_mask : (B, L_clipmax)
+            pose_dim (_type_): int
+            flexibility (float, optional): _description_. Defaults to 0.0.
+            ret_traj (bool, optional): _description_. Defaults to False.
+
+
+        Returns:
+            _type_: (B, L, C_face)
+        """
+        batch_size, output_len = audio.shape[:2]
+        # batch_size = context.size(0)
+        context = {
+            "audio": audio,
+            "style_clip": style_clip,
+            "style_pad_mask": style_pad_mask,
+        }
+        if use_cf_guidance:
+            uncond_style_clip = self.null_style_clip.unsqueeze(0).repeat(
+                batch_size, 1, 1
+            )
+            uncond_pad_mask = self.null_pad_mask.unsqueeze(0).repeat(batch_size, 1)
+            context_double = {
+                "audio": torch.cat([audio] * 2, dim=0),
+                "style_clip": torch.cat([style_clip, uncond_style_clip], dim=0),
+                "style_pad_mask": torch.cat([style_pad_mask, uncond_pad_mask], dim=0),
+            }
+
+        x_T = torch.randn([batch_size, output_len, output_dim]).to(audio.device)
+        traj = {self.var_sched.num_steps: x_T}
+        for t in range(self.var_sched.num_steps, 0, -1):
+            alpha = self.var_sched.alphas[t]
+            alpha_bar = self.var_sched.alpha_bars[t]
+            alpha_bar_prev = self.var_sched.alpha_bars[t - 1]
+            sigma = self.var_sched.get_sigmas(t, flexibility)
+
+            z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
+            x_t = traj[t]
+            t_tensor = torch.tensor([t] * batch_size).to(audio.device).float()
+            if use_cf_guidance:
+                x_t_double = torch.cat([x_t] * 2, dim=0)
+                t_tensor_double = torch.cat([t_tensor] * 2, dim=0)
+                cond_output, uncond_output = self.net(
+                    x_t_double, t=t_tensor_double, **context_double
+                ).chunk(2)
+                diff_output = uncond_output + cfg_scale * (cond_output - uncond_output)
+            else:
+                diff_output = self.net(x_t, t=t_tensor, **context)
+
+            if self.predict_what == "noise":
+                c0 = 1.0 / torch.sqrt(alpha)
+                c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
+                x_next = c0 * (x_t - c1 * diff_output) + sigma * z
+            elif self.predict_what == "x0":
+                d0 = torch.sqrt(alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar)
+                d1 = torch.sqrt(alpha_bar_prev) * (1 - alpha) / (1 - alpha_bar)
+                x_next = d0 * x_t + d1 * diff_output + sigma * z
+            traj[t - 1] = x_next.detach()
+            traj[t] = traj[t].cpu()
+            if not ret_traj:
+                del traj[t]
+
+        if ret_traj:
+            raise NotImplementedError
+            return traj
+        else:
+            latent_output = traj[0]
+            face3d_output = self._latent_to_face3d(latent_output)
+            return face3d_output
+
+
+if __name__ == "__main__":
+    from core.networks.diffusion_util import NoisePredictor, VarianceSchedule
+
+    diffnet = DiffusionNet(
+        net=NoisePredictor(),
+        var_sched=VarianceSchedule(
+            num_steps=500, beta_1=1e-4, beta_T=0.02, mode="linear"
+        ),
+    )
+
+    import torch
+
+    gt_face3d = torch.randn(16, 64, 64)
+    audio = torch.randn(16, 64, 11)
+    style_clip = torch.randn(16, 256, 64)
+    style_pad_mask = torch.ones(16, 256)
+
+    context = {
+        "audio": audio,
+        "style_clip": style_clip,
+        "style_pad_mask": style_pad_mask,
+    }
+
+    loss = diffnet.get_loss(gt_face3d, context)
+
+    print("hello")

+ 131 - 0
core/networks/diffusion_util.py

@@ -0,0 +1,131 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import Module
+from core.networks import get_network
+from core.utils import sinusoidal_embedding
+
+
+class VarianceSchedule(Module):
+    def __init__(self, num_steps, beta_1, beta_T, mode="linear"):
+        super().__init__()
+        assert mode in ("linear",)
+        self.num_steps = num_steps
+        self.beta_1 = beta_1
+        self.beta_T = beta_T
+        self.mode = mode
+
+        if mode == "linear":
+            betas = torch.linspace(beta_1, beta_T, steps=num_steps)
+
+        betas = torch.cat([torch.zeros([1]), betas], dim=0)  # Padding
+
+        alphas = 1 - betas
+        log_alphas = torch.log(alphas)
+        for i in range(1, log_alphas.size(0)):  # 1 to T
+            log_alphas[i] += log_alphas[i - 1]
+        alpha_bars = log_alphas.exp()
+
+        sigmas_flex = torch.sqrt(betas)
+        sigmas_inflex = torch.zeros_like(sigmas_flex)
+        for i in range(1, sigmas_flex.size(0)):
+            sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[
+                i
+            ]
+        sigmas_inflex = torch.sqrt(sigmas_inflex)
+
+        self.register_buffer("betas", betas)
+        self.register_buffer("alphas", alphas)
+        self.register_buffer("alpha_bars", alpha_bars)
+        self.register_buffer("sigmas_flex", sigmas_flex)
+        self.register_buffer("sigmas_inflex", sigmas_inflex)
+
+    def uniform_sample_t(self, batch_size):
+        ts = np.random.choice(np.arange(1, self.num_steps + 1), batch_size)
+        return ts.tolist()
+
+    def get_sigmas(self, t, flexibility):
+        assert 0 <= flexibility and flexibility <= 1
+        sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (
+            1 - flexibility
+        )
+        return sigmas
+
+
+class NoisePredictor(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+
+        content_encoder_class = get_network(cfg.CONTENT_ENCODER_TYPE)
+        self.content_encoder = content_encoder_class(**cfg.CONTENT_ENCODER)
+
+        style_encoder_class = get_network(cfg.STYLE_ENCODER_TYPE)
+        cfg.defrost()
+        cfg.STYLE_ENCODER.input_dim = cfg.DATASET.FACE3D_DIM
+        cfg.freeze()
+        self.style_encoder = style_encoder_class(**cfg.STYLE_ENCODER)
+
+        decoder_class = get_network(cfg.DECODER_TYPE)
+        cfg.defrost()
+        cfg.DECODER.output_dim = cfg.DATASET.FACE3D_DIM
+        cfg.freeze()
+        self.decoder = decoder_class(**cfg.DECODER)
+
+        self.content_xt_to_decoder_input_wo_time = nn.Sequential(
+            nn.Linear(cfg.D_MODEL + cfg.DATASET.FACE3D_DIM, cfg.D_MODEL),
+            nn.ReLU(),
+            nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
+            nn.ReLU(),
+            nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
+        )
+
+        self.time_sinusoidal_dim = cfg.D_MODEL
+        self.time_embed_net = nn.Sequential(
+            nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
+            nn.SiLU(),
+            nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
+        )
+
+    def forward(self, x_t, t, audio, style_clip, style_pad_mask, ready_style_code=None):
+        """_summary_
+
+        Args:
+            x_t (_type_): (B, L, C_face)
+            t (_type_): (B,) dtype:float32
+            audio (_type_): (B, L, W)
+            style_clip (_type_): (B, L_clipmax, C_face3d)
+            style_pad_mask : (B, L_clipmax)
+            ready_style_code: (B, C_model)
+        Returns:
+            e_theta : (B, L, C_face)
+        """
+        W = audio.shape[2]
+        content = self.content_encoder(audio)
+        # (B, L, W, C_model)
+        x_t_expand = x_t.unsqueeze(2).repeat(1, 1, W, 1)
+        # (B, L, C_face) -> (B, L, W, C_face)
+        content_xt_concat = torch.cat((content, x_t_expand), dim=3)
+        # (B, L, W, C_model+C_face)
+        decoder_input_without_time = self.content_xt_to_decoder_input_wo_time(
+            content_xt_concat
+        )
+        # (B, L, W, C_model)
+
+        time_sinusoidal = sinusoidal_embedding(t, self.time_sinusoidal_dim)
+        # (B, C_embed)
+        time_embedding = self.time_embed_net(time_sinusoidal)
+        # (B, C_model)
+        B, C = time_embedding.shape
+        time_embed_expand = time_embedding.view(B, 1, 1, C)
+        decoder_input = decoder_input_without_time + time_embed_expand
+        # (B, L, W, C_model)
+
+        if ready_style_code is not None:
+            style_code = ready_style_code
+        else:
+            style_code = self.style_encoder(style_clip, style_pad_mask)
+        # (B, C_model)
+
+        e_theta = self.decoder(decoder_input, style_code)
+        # (B, L, C_face)
+        return e_theta

+ 240 - 0
core/networks/disentangle_decoder.py

@@ -0,0 +1,240 @@
+import torch
+from torch import nn
+
+from .transformer import (
+    PositionalEncoding,
+    TransformerDecoderLayer,
+    TransformerDecoder,
+)
+from core.networks.dynamic_fc_decoder import DynamicFCDecoderLayer, DynamicFCDecoder
+from core.utils import _reset_parameters
+
+
+def get_decoder_network(
+    network_type,
+    d_model,
+    nhead,
+    dim_feedforward,
+    dropout,
+    activation,
+    normalize_before,
+    num_decoder_layers,
+    return_intermediate_dec,
+    dynamic_K,
+    dynamic_ratio,
+):
+    decoder = None
+    if network_type == "TransformerDecoder":
+        decoder_layer = TransformerDecoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        norm = nn.LayerNorm(d_model)
+        decoder = TransformerDecoder(
+            decoder_layer,
+            num_decoder_layers,
+            norm,
+            return_intermediate_dec,
+        )
+    elif network_type == "DynamicFCDecoder":
+        d_style = d_model
+        decoder_layer = DynamicFCDecoderLayer(
+            d_model,
+            nhead,
+            d_style,
+            dynamic_K,
+            dynamic_ratio,
+            dim_feedforward,
+            dropout,
+            activation,
+            normalize_before,
+        )
+        norm = nn.LayerNorm(d_model)
+        decoder = DynamicFCDecoder(
+            decoder_layer, num_decoder_layers, norm, return_intermediate_dec
+        )
+    elif network_type == "DynamicFCEncoder":
+        d_style = d_model
+        decoder_layer = DynamicFCEncoderLayer(
+            d_model,
+            nhead,
+            d_style,
+            dynamic_K,
+            dynamic_ratio,
+            dim_feedforward,
+            dropout,
+            activation,
+            normalize_before,
+        )
+        norm = nn.LayerNorm(d_model)
+        decoder = DynamicFCEncoder(decoder_layer, num_decoder_layers, norm)
+
+    else:
+        raise ValueError(f"Invalid network_type {network_type}")
+
+    return decoder
+
+
+class DisentangleDecoder(nn.Module):
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_decoder_layers=3,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+        return_intermediate_dec=False,
+        pos_embed_len=80,
+        upper_face3d_indices=tuple(list(range(19)) + list(range(46, 51))),
+        lower_face3d_indices=tuple(range(19, 46)),
+        network_type="None",
+        dynamic_K=None,
+        dynamic_ratio=None,
+        **_,
+    ) -> None:
+        super().__init__()
+
+        self.upper_face3d_indices = upper_face3d_indices
+        self.lower_face3d_indices = lower_face3d_indices
+
+        # upper_decoder_layer = TransformerDecoderLayer(
+        #     d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        # )
+        # upper_decoder_norm = nn.LayerNorm(d_model)
+        # self.upper_decoder = TransformerDecoder(
+        #     upper_decoder_layer,
+        #     num_decoder_layers,
+        #     upper_decoder_norm,
+        #     return_intermediate=return_intermediate_dec,
+        # )
+        self.upper_decoder = get_decoder_network(
+            network_type,
+            d_model,
+            nhead,
+            dim_feedforward,
+            dropout,
+            activation,
+            normalize_before,
+            num_decoder_layers,
+            return_intermediate_dec,
+            dynamic_K,
+            dynamic_ratio,
+        )
+        _reset_parameters(self.upper_decoder)
+
+        # lower_decoder_layer = TransformerDecoderLayer(
+        #     d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        # )
+        # lower_decoder_norm = nn.LayerNorm(d_model)
+        # self.lower_decoder = TransformerDecoder(
+        #     lower_decoder_layer,
+        #     num_decoder_layers,
+        #     lower_decoder_norm,
+        #     return_intermediate=return_intermediate_dec,
+        # )
+        self.lower_decoder = get_decoder_network(
+            network_type,
+            d_model,
+            nhead,
+            dim_feedforward,
+            dropout,
+            activation,
+            normalize_before,
+            num_decoder_layers,
+            return_intermediate_dec,
+            dynamic_K,
+            dynamic_ratio,
+        )
+        _reset_parameters(self.lower_decoder)
+
+        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
+
+        tail_hidden_dim = d_model // 2
+        self.upper_tail_fc = nn.Sequential(
+            nn.Linear(d_model, tail_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(tail_hidden_dim, tail_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(tail_hidden_dim, len(upper_face3d_indices)),
+        )
+        self.lower_tail_fc = nn.Sequential(
+            nn.Linear(d_model, tail_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(tail_hidden_dim, tail_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(tail_hidden_dim, len(lower_face3d_indices)),
+        )
+
+    def forward(self, content, style_code):
+        """
+
+        Args:
+            content (_type_): (B, num_frames, window, C_dmodel)
+            style_code (_type_): (B, C_dmodel)
+
+        Returns:
+            face3d: (B, L_clip, C_3dmm)
+        """
+        B, N, W, C = content.shape
+        style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
+        style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
+        # (W, B*N, C)
+
+        content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
+        # (W, B*N, C)
+        tgt = torch.zeros_like(style)
+        pos_embed = self.pos_embed(W)
+        pos_embed = pos_embed.permute(1, 0, 2)
+
+        upper_face3d_feat = self.upper_decoder(
+            tgt, content, pos=pos_embed, query_pos=style
+        )[0]
+        # (W, B*N, C)
+        upper_face3d_feat = upper_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[
+            :, :, W // 2, :
+        ]
+        # (B, N, C)
+        upper_face3d = self.upper_tail_fc(upper_face3d_feat)
+        # (B, N, C_exp)
+
+        lower_face3d_feat = self.lower_decoder(
+            tgt, content, pos=pos_embed, query_pos=style
+        )[0]
+        lower_face3d_feat = lower_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[
+            :, :, W // 2, :
+        ]
+        lower_face3d = self.lower_tail_fc(lower_face3d_feat)
+        C_exp = len(self.upper_face3d_indices) + len(self.lower_face3d_indices)
+        face3d = torch.zeros(B, N, C_exp).to(upper_face3d)
+        face3d[:, :, self.upper_face3d_indices] = upper_face3d
+        face3d[:, :, self.lower_face3d_indices] = lower_face3d
+        return face3d
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.path.append("/home/mayifeng/Research/styleTH")
+
+    from configs.default import get_cfg_defaults
+
+    cfg = get_cfg_defaults()
+    cfg.merge_from_file("configs/styleTH_unpair_lsfm_emotion.yaml")
+    cfg.freeze()
+
+    # content_encoder = ContentEncoder(**cfg.CONTENT_ENCODER)
+
+    # dummy_audio = torch.randint(0, 41, (5, 64, 11))
+    # dummy_content = content_encoder(dummy_audio)
+
+    # style_encoder = StyleEncoder(**cfg.STYLE_ENCODER)
+    # dummy_face3d_seq = torch.randn(5, 64, 64)
+    # dummy_style_code = style_encoder(dummy_face3d_seq)
+
+    decoder = DisentangleDecoder(**cfg.DECODER)
+    dummy_content = torch.randn(5, 64, 11, 256)
+    dummy_style = torch.randn(5, 256)
+    dummy_output = decoder(dummy_content, dummy_style)
+
+    print("hello")

+ 156 - 0
core/networks/dynamic_conv.py

@@ -0,0 +1,156 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class Attention(nn.Module):
+    def __init__(self, cond_planes, ratio, K, temperature=30, init_weight=True):
+        super().__init__()
+        # self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.temprature = temperature
+        assert cond_planes > ratio
+        hidden_planes = cond_planes // ratio
+        self.net = nn.Sequential(
+            nn.Conv2d(cond_planes, hidden_planes, kernel_size=1, bias=False),
+            nn.ReLU(),
+            nn.Conv2d(hidden_planes, K, kernel_size=1, bias=False),
+        )
+
+        if init_weight:
+            self._initialize_weights()
+
+    def update_temprature(self):
+        if self.temprature > 1:
+            self.temprature -= 1
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            if isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, cond):
+        """
+
+        Args:
+            cond (_type_): (B, C_style)
+
+        Returns:
+            _type_: (B, K)
+        """
+
+        # att = self.avgpool(cond)  # bs,dim,1,1
+        att = cond.view(cond.shape[0], cond.shape[1], 1, 1)
+        att = self.net(att).view(cond.shape[0], -1)  # bs,K
+        return F.softmax(att / self.temprature, -1)
+
+
+class DynamicConv(nn.Module):
+    def __init__(
+        self,
+        in_planes,
+        out_planes,
+        cond_planes,
+        kernel_size,
+        stride,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+        K=4,
+        temperature=30,
+        ratio=4,
+        init_weight=True,
+    ):
+        super().__init__()
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+        self.cond_planes = cond_planes
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.bias = bias
+        self.K = K
+        self.init_weight = init_weight
+        self.attention = Attention(
+            cond_planes=cond_planes, ratio=ratio, K=K, temperature=temperature, init_weight=init_weight
+        )
+
+        self.weight = nn.Parameter(
+            torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size), requires_grad=True
+        )
+        if bias:
+            self.bias = nn.Parameter(torch.randn(K, out_planes), requires_grad=True)
+        else:
+            self.bias = None
+
+        if self.init_weight:
+            self._initialize_weights()
+
+    def _initialize_weights(self):
+        for i in range(self.K):
+            nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
+            if self.bias is not None:
+                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
+                if fan_in != 0:
+                    bound = 1 / math.sqrt(fan_in)
+                    nn.init.uniform_(self.bias, -bound, bound)
+
+    def forward(self, x, cond):
+        """
+
+        Args:
+            x (_type_): (B, C_in, L, 1)
+            cond (_type_): (B, C_style)
+
+        Returns:
+            _type_: (B, C_out, L, 1)
+        """
+        bs, in_planels, h, w = x.shape
+        softmax_att = self.attention(cond)  # bs,K
+        x = x.view(1, -1, h, w)
+        weight = self.weight.view(self.K, -1)  # K,-1
+        aggregate_weight = torch.mm(softmax_att, weight).view(
+            bs * self.out_planes, self.in_planes // self.groups, self.kernel_size, self.kernel_size
+        )  # bs*out_p,in_p,k,k
+
+        if self.bias is not None:
+            bias = self.bias.view(self.K, -1)  # K,out_p
+            aggregate_bias = torch.mm(softmax_att, bias).view(-1)  # bs*out_p
+            output = F.conv2d(
+                x, # 1, bs*in_p, L, 1
+                weight=aggregate_weight,
+                bias=aggregate_bias,
+                stride=self.stride,
+                padding=self.padding,
+                groups=self.groups * bs,
+                dilation=self.dilation,
+            )
+        else:
+            output = F.conv2d(
+                x,
+                weight=aggregate_weight,
+                bias=None,
+                stride=self.stride,
+                padding=self.padding,
+                groups=self.groups * bs,
+                dilation=self.dilation,
+            )
+
+        output = output.view(bs, self.out_planes, h, w)
+        return output
+
+
+if __name__ == "__main__":
+    input = torch.randn(3, 32, 64, 64)
+    m = DynamicConv(in_planes=32, out_planes=64, kernel_size=3, stride=1, padding=1, bias=True)
+    out = m(input)
+    print(out.shape)

+ 178 - 0
core/networks/dynamic_fc_decoder.py

@@ -0,0 +1,178 @@
+import torch.nn as nn
+import torch
+
+from core.networks.transformer import _get_activation_fn, _get_clones
+from core.networks.dynamic_linear import DynamicLinear
+
+
+class DynamicFCDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        d_style,
+        dynamic_K,
+        dynamic_ratio,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        # self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.linear1 = DynamicLinear(d_model, dim_feedforward, d_style, K=dynamic_K, ratio=dynamic_ratio)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+        # self.linear2 = DynamicLinear(dim_feedforward, d_model, d_style, K=dynamic_K, ratio=dynamic_ratio)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(
+        self,
+        tgt,
+        memory,
+        style,
+        tgt_mask=None,
+        memory_mask=None,
+        tgt_key_padding_mask=None,
+        memory_key_padding_mask=None,
+        pos=None,
+        query_pos=None,
+    ):
+        # q = k = self.with_pos_embed(tgt, query_pos)
+        tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+        tgt2 = self.multihead_attn(
+            query=tgt, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+        # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))), style)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+        return tgt
+
+    # def forward_pre(
+    #     self,
+    #     tgt,
+    #     memory,
+    #     tgt_mask=None,
+    #     memory_mask=None,
+    #     tgt_key_padding_mask=None,
+    #     memory_key_padding_mask=None,
+    #     pos=None,
+    #     query_pos=None,
+    # ):
+    #     tgt2 = self.norm1(tgt)
+    #     # q = k = self.with_pos_embed(tgt2, query_pos)
+    #     tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+    #     tgt = tgt + self.dropout1(tgt2)
+    #     tgt2 = self.norm2(tgt)
+    #     tgt2 = self.multihead_attn(
+    #         query=tgt2, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
+    #     )[0]
+    #     tgt = tgt + self.dropout2(tgt2)
+    #     tgt2 = self.norm3(tgt)
+    #     tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+    #     tgt = tgt + self.dropout3(tgt2)
+    #     return tgt
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        style,
+        tgt_mask=None,
+        memory_mask=None,
+        tgt_key_padding_mask=None,
+        memory_key_padding_mask=None,
+        pos=None,
+        query_pos=None,
+    ):
+        if self.normalize_before:
+            raise NotImplementedError
+            # return self.forward_pre(
+            #     tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
+            # )
+        return self.forward_post(
+            tgt, memory, style, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
+        )
+
+
+class DynamicFCDecoder(nn.Module):
+    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+        super().__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+        self.return_intermediate = return_intermediate
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        tgt_mask=None,
+        memory_mask=None,
+        tgt_key_padding_mask=None,
+        memory_key_padding_mask=None,
+        pos=None,
+        query_pos=None,
+    ):
+        style = query_pos[0]
+        # (B*N, C)
+        output = tgt + pos + query_pos
+
+        intermediate = []
+
+        for layer in self.layers:
+            output = layer(
+                output,
+                memory,
+                style,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                pos=pos,
+                query_pos=query_pos,
+            )
+            if self.return_intermediate:
+                intermediate.append(self.norm(output))
+
+        if self.norm is not None:
+            output = self.norm(output)
+            if self.return_intermediate:
+                intermediate.pop()
+                intermediate.append(output)
+
+        if self.return_intermediate:
+            return torch.stack(intermediate)
+
+        return output.unsqueeze(0)
+
+
+if __name__ == "__main__":
+    query = torch.randn(11, 1024, 256)
+    content = torch.randn(11, 1024, 256)
+    style = torch.randn(1024, 256)
+    pos = torch.randn(11, 1, 256)
+    m = DynamicFCDecoderLayer(256, 4, 256, 4, 4, 1024)
+
+    out = m(query, content, style, pos=pos)
+    print(out.shape)

+ 50 - 0
core/networks/dynamic_linear.py

@@ -0,0 +1,50 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from core.networks.dynamic_conv import DynamicConv
+
+
+class DynamicLinear(nn.Module):
+    def __init__(self, in_planes, out_planes, cond_planes, bias=True, K=4, temperature=30, ratio=4, init_weight=True):
+        super().__init__()
+
+        self.dynamic_conv = DynamicConv(
+            in_planes,
+            out_planes,
+            cond_planes,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            K=K,
+            ratio=ratio,
+            temperature=temperature,
+            init_weight=init_weight,
+        )
+
+    def forward(self, x, cond):
+        """
+
+        Args:
+            x (_type_): (L, B, C_in)
+            cond (_type_): (B, C_style)
+
+        Returns:
+            _type_: (L, B, C_out)
+        """
+        x = x.permute(1, 2, 0).unsqueeze(-1)
+        out = self.dynamic_conv(x, cond)
+        # (B, C_out, L, 1)
+        out = out.squeeze().permute(2, 0, 1)
+        return out
+
+
+if __name__ == "__main__":
+    input = torch.randn(11, 1024, 255)
+    cond = torch.randn(1024, 256)
+    m = DynamicLinear(255, 1000, 256, K=7, temperature=5, ratio=8)
+    out = m(input, cond)
+    print(out.shape)

+ 309 - 0
core/networks/generator.py

@@ -0,0 +1,309 @@
+import torch
+from torch import nn
+
+from .transformer import (
+    TransformerEncoder,
+    TransformerEncoderLayer,
+    PositionalEncoding,
+    TransformerDecoderLayer,
+    TransformerDecoder,
+)
+from core.utils import _reset_parameters
+from core.networks.self_attention_pooling import SelfAttentionPooling
+
+
+# class ContentEncoder(nn.Module):
+#     def __init__(
+#         self,
+#         d_model=512,
+#         nhead=8,
+#         num_encoder_layers=6,
+#         dim_feedforward=2048,
+#         dropout=0.1,
+#         activation="relu",
+#         normalize_before=False,
+#         pos_embed_len=80,
+#         ph_embed_dim=128,
+#     ):
+#         super().__init__()
+
+#         encoder_layer = TransformerEncoderLayer(
+#             d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+#         )
+#         encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+#         self.encoder = TransformerEncoder(
+#             encoder_layer, num_encoder_layers, encoder_norm
+#         )
+
+#         _reset_parameters(self.encoder)
+
+#         self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
+
+#         self.ph_embedding = nn.Embedding(41, ph_embed_dim)
+#         self.increase_embed_dim = nn.Linear(ph_embed_dim, d_model)
+
+#     def forward(self, x):
+#         """
+
+#         Args:
+#             x (_type_): (B, num_frames, window)
+
+#         Returns:
+#             content: (B, num_frames, window, C_dmodel)
+#         """
+#         x_embedding = self.ph_embedding(x)
+#         x_embedding = self.increase_embed_dim(x_embedding)
+#         # (B, N, W, C)
+#         B, N, W, C = x_embedding.shape
+#         x_embedding = x_embedding.reshape(B * N, W, C)
+#         x_embedding = x_embedding.permute(1, 0, 2)
+#         # (W, B*N, C)
+
+#         pos = self.pos_embed(W)
+#         pos = pos.permute(1, 0, 2)
+#         # (W, 1, C)
+
+#         content = self.encoder(x_embedding, pos=pos)
+#         # (W, B*N, C)
+#         content = content.permute(1, 0, 2).reshape(B, N, W, C)
+#         # (B, N, W, C)
+
+#         return content
+
+
+class ContentW2VEncoder(nn.Module):
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_encoder_layers=6,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+        pos_embed_len=80,
+        ph_embed_dim=128,
+    ):
+        super().__init__()
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+        self.encoder = TransformerEncoder(
+            encoder_layer, num_encoder_layers, encoder_norm
+        )
+
+        _reset_parameters(self.encoder)
+
+        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
+
+        self.increase_embed_dim = nn.Linear(1024, d_model)
+
+    def forward(self, x):
+        """
+        Args:
+            x (_type_): (B, num_frames, window, C_wav2vec)
+
+        Returns:
+            content: (B, num_frames, window, C_dmodel)
+        """
+        x_embedding = self.increase_embed_dim(
+            x
+        )  # [16, 64, 11, 1024] -> [16, 64, 11, 256]
+        # (B, N, W, C)
+        B, N, W, C = x_embedding.shape
+        x_embedding = x_embedding.reshape(B * N, W, C)
+        x_embedding = x_embedding.permute(1, 0, 2)  # [11, 1024, 256]
+        # (W, B*N, C)
+
+        pos = self.pos_embed(W)
+        pos = pos.permute(1, 0, 2)  # [11, 1, 256]
+        # (W, 1, C)
+
+        content = self.encoder(x_embedding, pos=pos)  # [11, 1024, 256]
+        # (W, B*N, C)
+        content = content.permute(1, 0, 2).reshape(B, N, W, C)
+        # (B, N, W, C)
+
+        return content
+
+
+class StyleEncoder(nn.Module):
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_encoder_layers=6,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+        pos_embed_len=80,
+        input_dim=128,
+        aggregate_method="average",
+    ):
+        super().__init__()
+        encoder_layer = TransformerEncoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+        self.encoder = TransformerEncoder(
+            encoder_layer, num_encoder_layers, encoder_norm
+        )
+        _reset_parameters(self.encoder)
+
+        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
+
+        self.increase_embed_dim = nn.Linear(input_dim, d_model)
+
+        self.aggregate_method = None
+        if aggregate_method == "self_attention_pooling":
+            self.aggregate_method = SelfAttentionPooling(d_model)
+        elif aggregate_method == "average":
+            pass
+        else:
+            raise ValueError(f"Invalid aggregate method {aggregate_method}")
+
+    def forward(self, x, pad_mask=None):
+        """
+
+        Args:
+            x (_type_): (B, num_frames(L), C_exp)
+            pad_mask: (B, num_frames)
+
+        Returns:
+            style_code: (B, C_model)
+        """
+        x = self.increase_embed_dim(x)
+        # (B, L, C)
+        x = x.permute(1, 0, 2)
+        # (L, B, C)
+
+        pos = self.pos_embed(x.shape[0])
+        pos = pos.permute(1, 0, 2)
+        # (L, 1, C)
+
+        style = self.encoder(x, pos=pos, src_key_padding_mask=pad_mask)
+        # (L, B, C)
+
+        if self.aggregate_method is not None:
+            permute_style = style.permute(1, 0, 2)
+            # (B, L, C)
+            style_code = self.aggregate_method(permute_style, pad_mask)
+            return style_code
+
+        if pad_mask is None:
+            style = style.permute(1, 2, 0)
+            # (B, C, L)
+            style_code = style.mean(2)
+            # (B, C)
+        else:
+            permute_style = style.permute(1, 0, 2)
+            # (B, L, C)
+            permute_style[pad_mask] = 0
+            sum_style_code = permute_style.sum(dim=1)
+            # (B, C)
+            valid_token_num = (~pad_mask).sum(dim=1).unsqueeze(-1)
+            # (B, 1)
+            style_code = sum_style_code / valid_token_num
+            # (B, C)
+
+        return style_code
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_decoder_layers=3,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+        return_intermediate_dec=False,
+        pos_embed_len=80,
+        output_dim=64,
+        **_,
+    ) -> None:
+        super().__init__()
+
+        decoder_layer = TransformerDecoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        decoder_norm = nn.LayerNorm(d_model)
+        self.decoder = TransformerDecoder(
+            decoder_layer,
+            num_decoder_layers,
+            decoder_norm,
+            return_intermediate=return_intermediate_dec,
+        )
+        _reset_parameters(self.decoder)
+
+        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
+
+        tail_hidden_dim = d_model // 2
+        self.tail_fc = nn.Sequential(
+            nn.Linear(d_model, tail_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(tail_hidden_dim, tail_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(tail_hidden_dim, output_dim),
+        )
+
+    def forward(self, content, style_code):
+        """
+
+        Args:
+            content (_type_): (B, num_frames, window, C_dmodel)
+            style_code (_type_): (B, C_dmodel)
+
+        Returns:
+            face3d: (B, num_frames, C_3dmm)
+        """
+        B, N, W, C = content.shape
+        style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
+        style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
+        # (W, B*N, C)
+
+        content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
+        # (W, B*N, C)
+        tgt = torch.zeros_like(style)
+        pos_embed = self.pos_embed(W)
+        pos_embed = pos_embed.permute(1, 0, 2)
+        face3d_feat = self.decoder(tgt, content, pos=pos_embed, query_pos=style)[0]
+        # (W, B*N, C)
+        face3d_feat = face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]
+        # (B, N, C)
+        face3d = self.tail_fc(face3d_feat)
+        # (B, N, C_exp)
+        return face3d
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.path.append("/home/mayifeng/Research/styleTH")
+
+    from configs.default import get_cfg_defaults
+
+    cfg = get_cfg_defaults()
+    cfg.merge_from_file("configs/styleTH_bp.yaml")
+    cfg.freeze()
+
+    # content_encoder = ContentEncoder(**cfg.CONTENT_ENCODER)
+
+    # dummy_audio = torch.randint(0, 41, (5, 64, 11))
+    # dummy_content = content_encoder(dummy_audio)
+
+    # style_encoder = StyleEncoder(**cfg.STYLE_ENCODER)
+    # dummy_face3d_seq = torch.randn(5, 64, 64)
+    # dummy_style_code = style_encoder(dummy_face3d_seq)
+
+    decoder = Decoder(**cfg.DECODER)
+    dummy_content = torch.randn(5, 64, 11, 512)
+    dummy_style = torch.randn(5, 512)
+    dummy_output = decoder(dummy_content, dummy_style)
+
+    print("hello")

+ 51 - 0
core/networks/mish.py

@@ -0,0 +1,51 @@
+"""
+Applies the mish function element-wise:
+mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+"""
+
+# import pytorch
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+@torch.jit.script
+def mish(input):
+    """
+    Applies the mish function element-wise:
+    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+    See additional documentation for mish class.
+    """
+    return input * torch.tanh(F.softplus(input))
+
+class Mish(nn.Module):
+    """
+    Applies the mish function element-wise:
+    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+
+    Shape:
+        - Input: (N, *) where * means, any number of additional
+          dimensions
+        - Output: (N, *), same shape as the input
+
+    Examples:
+        >>> m = Mish()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+
+    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
+    """
+
+    def __init__(self):
+        """
+        Init method.
+        """
+        super().__init__()
+
+    def forward(self, input):
+        """
+        Forward pass of the function.
+        """
+        if torch.__version__ >= "1.9":
+            return F.mish(input)
+        else:
+            return mish(input)

+ 53 - 0
core/networks/self_attention_pooling.py

@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from core.networks.mish import Mish
+
+
+class SelfAttentionPooling(nn.Module):
+    """
+    Implementation of SelfAttentionPooling
+    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
+    https://arxiv.org/pdf/2008.01077v1.pdf
+    """
+
+    def __init__(self, input_dim):
+        super(SelfAttentionPooling, self).__init__()
+        self.W = nn.Sequential(nn.Linear(input_dim, input_dim), Mish(), nn.Linear(input_dim, 1))
+        self.softmax = nn.functional.softmax
+
+    def forward(self, batch_rep, att_mask=None):
+        """
+        N: batch size, T: sequence length, H: Hidden dimension
+        input:
+            batch_rep : size (N, T, H)
+        attention_weight:
+            att_w : size (N, T, 1)
+        att_mask:
+            att_mask: size (N, T): if True, mask this item.
+        return:
+            utter_rep: size (N, H)
+        """
+
+        att_logits = self.W(batch_rep).squeeze(-1)
+        # (N, T)
+        if att_mask is not None:
+            att_mask_logits = att_mask.to(dtype=batch_rep.dtype) * -100000.0
+            # (N, T)
+            att_logits = att_mask_logits + att_logits
+
+        att_w = self.softmax(att_logits, dim=-1).unsqueeze(-1)
+        utter_rep = torch.sum(batch_rep * att_w, dim=1)
+
+        return utter_rep
+
+
+if __name__ == "__main__":
+    batch = torch.randn(8, 64, 256)
+    self_attn_pool = SelfAttentionPooling(256)
+    att_mask = torch.zeros(8, 64)
+    att_mask[:, 60:] = 1
+    att_mask = att_mask.to(torch.bool)
+    output = self_attn_pool(batch, att_mask)
+    # (8, 256)
+
+    print("hello")

+ 293 - 0
core/networks/transformer.py

@@ -0,0 +1,293 @@
+import torch.nn as nn
+import torch
+import numpy as np
+import torch.nn.functional as F
+import copy
+
+
+class PositionalEncoding(nn.Module):
+
+    def __init__(self, d_hid, n_position=200):
+        super(PositionalEncoding, self).__init__()
+
+        # Not a parameter
+        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
+
+    def _get_sinusoid_encoding_table(self, n_position, d_hid):
+        ''' Sinusoid position encoding table '''
+        # TODO: make it with torch instead of numpy
+
+        def get_position_angle_vec(position):
+            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
+
+        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
+        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+
+        return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+
+    def forward(self, winsize):
+        return self.pos_table[:, :winsize].clone().detach()
+
+def _get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
+
+def _get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+class Transformer(nn.Module):
+
+    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
+                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
+                 activation="relu", normalize_before=False,
+                 return_intermediate_dec=True):
+        super().__init__()
+
+        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
+                                                dropout, activation, normalize_before)
+        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
+                                                dropout, activation, normalize_before)
+        decoder_norm = nn.LayerNorm(d_model)
+        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
+                                          return_intermediate=return_intermediate_dec)
+
+        self._reset_parameters()
+
+        self.d_model = d_model
+        self.nhead = nhead
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self,opt, src, query_embed, pos_embed):
+        # flatten NxCxHxW to HWxNxC
+
+        src = src.permute(1, 0, 2)
+        pos_embed = pos_embed.permute(1, 0, 2)
+        query_embed = query_embed.permute(1, 0, 2)
+
+        tgt = torch.zeros_like(query_embed)
+        memory = self.encoder(src, pos=pos_embed)
+
+        hs = self.decoder(tgt, memory,
+                          pos=pos_embed, query_pos=query_embed)
+        return hs
+
+
+class TransformerEncoder(nn.Module):
+
+    def __init__(self, encoder_layer, num_layers, norm=None):
+        super().__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(self, src, mask = None, src_key_padding_mask = None, pos = None):
+        output = src+pos
+
+        for layer in self.layers:
+            output = layer(output, src_mask=mask,
+                           src_key_padding_mask=src_key_padding_mask, pos=pos)
+
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
+
+
+class TransformerDecoder(nn.Module):
+
+    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+        super().__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+        self.return_intermediate = return_intermediate
+
+    def forward(self, tgt, memory,  tgt_mask = None,  memory_mask = None, tgt_key_padding_mask = None,
+                memory_key_padding_mask = None,
+                pos = None,
+                query_pos = None):
+        output = tgt+pos+query_pos
+
+        intermediate = []
+
+        for layer in self.layers:
+            output = layer(output, memory, tgt_mask=tgt_mask,
+                           memory_mask=memory_mask,
+                           tgt_key_padding_mask=tgt_key_padding_mask,
+                           memory_key_padding_mask=memory_key_padding_mask,
+                           pos=pos, query_pos=query_pos)
+            if self.return_intermediate:
+                intermediate.append(self.norm(output))
+
+        if self.norm is not None:
+            output = self.norm(output)
+            if self.return_intermediate:
+                intermediate.pop()
+                intermediate.append(output)
+
+        if self.return_intermediate:
+            return torch.stack(intermediate)
+
+        return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+
+    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+                 activation="relu", normalize_before=False):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(self,
+                     src,
+                     src_mask = None,
+                     src_key_padding_mask = None,
+                     pos = None):
+        # q = k = self.with_pos_embed(src, pos)
+        src2 = self.self_attn(src, src, value=src, attn_mask=src_mask,
+                              key_padding_mask=src_key_padding_mask)[0]
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        return src
+
+    def forward_pre(self, src,
+                    src_mask = None,
+                    src_key_padding_mask = None,
+                    pos = None):
+        src2 = self.norm1(src)
+        # q = k = self.with_pos_embed(src2, pos)
+        src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask,
+                              key_padding_mask=src_key_padding_mask)[0]
+        src = src + self.dropout1(src2)
+        src2 = self.norm2(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+        src = src + self.dropout2(src2)
+        return src
+
+    def forward(self, src,
+                src_mask = None,
+                src_key_padding_mask = None,
+                pos = None):
+        if self.normalize_before:
+            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+
+    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+                 activation="relu", normalize_before=False):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(self, tgt, memory,
+                     tgt_mask = None,
+                     memory_mask = None,
+                     tgt_key_padding_mask = None,
+                     memory_key_padding_mask = None,
+                     pos = None,
+                     query_pos = None):
+        # q = k = self.with_pos_embed(tgt, query_pos)
+        tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask,
+                              key_padding_mask=tgt_key_padding_mask)[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+        tgt2 = self.multihead_attn(query=tgt,
+                                   key=memory,
+                                   value=memory, attn_mask=memory_mask,
+                                   key_padding_mask=memory_key_padding_mask)[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+        return tgt
+
+    def forward_pre(self, tgt, memory,
+                    tgt_mask = None,
+                    memory_mask = None,
+                    tgt_key_padding_mask = None,
+                    memory_key_padding_mask = None,
+                    pos = None,
+                    query_pos = None):
+        tgt2 = self.norm1(tgt)
+        # q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask,
+                              key_padding_mask=tgt_key_padding_mask)[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.multihead_attn(query=tgt2,
+                                   key=memory,
+                                   value=memory, attn_mask=memory_mask,
+                                   key_padding_mask=memory_key_padding_mask)[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt
+
+    def forward(self, tgt, memory,
+                tgt_mask = None,
+                memory_mask = None,
+                tgt_key_padding_mask = None,
+                memory_key_padding_mask = None,
+                pos = None,
+                query_pos = None):
+        if self.normalize_before:
+            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
+                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
+        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
+                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
+
+
+

+ 456 - 0
core/utils.py

@@ -0,0 +1,456 @@
+import os
+import argparse
+from collections import defaultdict
+import logging
+import pickle
+import json
+
+import numpy as np
+import torch
+from torch import nn
+from scipy.io import loadmat
+
+from configs.default import get_cfg_defaults
+import dlib
+import cv2
+
+
+def _reset_parameters(model):
+    for p in model.parameters():
+        if p.dim() > 1:
+            nn.init.xavier_uniform_(p)
+
+
+def get_video_style(video_name, style_type):
+    person_id, direction, emotion, level, *_ = video_name.split("_")
+    if style_type == "id_dir_emo_level":
+        style = "_".join([person_id, direction, emotion, level])
+    elif style_type == "emotion":
+        style = emotion
+    elif style_type == "id":
+        style = person_id
+    else:
+        raise ValueError("Unknown style type")
+
+    return style
+
+
+def get_style_video_lists(video_list, style_type):
+    style2video_list = defaultdict(list)
+    for video in video_list:
+        style = get_video_style(video, style_type)
+        style2video_list[style].append(video)
+
+    return style2video_list
+
+
+def get_face3d_clip(
+    video_name, video_root_dir, num_frames, start_idx, dtype=torch.float32
+):
+    """_summary_
+
+    Args:
+        video_name (_type_): _description_
+        video_root_dir (_type_): _description_
+        num_frames (_type_): _description_
+        start_idx (_type_): "random" , middle, int
+        dtype (_type_, optional): _description_. Defaults to torch.float32.
+
+    Raises:
+        ValueError: _description_
+        ValueError: _description_
+
+    Returns:
+        _type_: _description_
+    """
+    video_path = os.path.join(video_root_dir, video_name)
+    if video_path[-3:] == "mat":
+        face3d_all = loadmat(video_path)["coeff"]
+        face3d_exp = face3d_all[:, 80:144]  # expression 3DMM range
+    elif video_path[-3:] == "txt":
+        face3d_exp = np.loadtxt(video_path)
+    else:
+        raise ValueError("Invalid 3DMM file extension")
+
+    length = face3d_exp.shape[0]
+    clip_num_frames = num_frames
+    if start_idx == "random":
+        clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
+    elif start_idx == "middle":
+        clip_start_idx = (length - clip_num_frames + 1) // 2
+    elif isinstance(start_idx, int):
+        clip_start_idx = start_idx
+    else:
+        raise ValueError(f"Invalid start_idx {start_idx}")
+
+    face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
+    face3d_clip = torch.tensor(face3d_clip, dtype=dtype)
+
+    return face3d_clip
+
+
+def get_video_style_clip(
+    video_name,
+    video_root_dir,
+    style_max_len,
+    start_idx="random",
+    dtype=torch.float32,
+    return_start_idx=False,
+):
+    video_path = os.path.join(video_root_dir, video_name)
+    if video_path[-3:] == "mat":
+        face3d_all = loadmat(video_path)["coeff"]
+        face3d_exp = face3d_all[:, 80:144]  # expression 3DMM range
+    elif video_path[-3:] == "txt":
+        face3d_exp = np.loadtxt(video_path)
+    else:
+        raise ValueError("Invalid 3DMM file extension")
+
+    face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
+
+    length = face3d_exp.shape[0]
+    if length >= style_max_len:
+        clip_num_frames = style_max_len
+        if start_idx == "random":
+            clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
+        elif start_idx == "middle":
+            clip_start_idx = (length - clip_num_frames + 1) // 2
+        elif isinstance(start_idx, int):
+            clip_start_idx = start_idx
+        else:
+            raise ValueError(f"Invalid start_idx {start_idx}")
+
+        face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
+        pad_mask = torch.tensor([False] * style_max_len)
+    else:
+        clip_start_idx = None
+        padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
+        face3d_clip = torch.cat((face3d_exp, padding), dim=0)
+        pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))
+
+    if return_start_idx:
+        return face3d_clip, pad_mask, clip_start_idx
+    else:
+        return face3d_clip, pad_mask
+
+
+def get_video_style_clip_from_np(
+    face3d_exp,
+    style_max_len,
+    start_idx="random",
+    dtype=torch.float32,
+    return_start_idx=False,
+):
+    face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
+
+    length = face3d_exp.shape[0]
+    if length >= style_max_len:
+        clip_num_frames = style_max_len
+        if start_idx == "random":
+            clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
+        elif start_idx == "middle":
+            clip_start_idx = (length - clip_num_frames + 1) // 2
+        elif isinstance(start_idx, int):
+            clip_start_idx = start_idx
+        else:
+            raise ValueError(f"Invalid start_idx {start_idx}")
+
+        face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
+        pad_mask = torch.tensor([False] * style_max_len)
+    else:
+        clip_start_idx = None
+        padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
+        face3d_clip = torch.cat((face3d_exp, padding), dim=0)
+        pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))
+
+    if return_start_idx:
+        return face3d_clip, pad_mask, clip_start_idx
+    else:
+        return face3d_clip, pad_mask
+
+
+def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
+    """
+
+    Args:
+        audio_feat (np.ndarray): (N, 1024)
+        start_idx (_type_): _description_
+        num_frames (_type_): _description_
+    """
+    center_idx_list = [2 * idx for idx in range(start_idx, start_idx + num_frames)]
+    audio_window_list = []
+    padding = np.zeros(audio_feat.shape[1], dtype=np.float32)
+    for center_idx in center_idx_list:
+        cur_audio_window = []
+        for i in range(center_idx - win_size, center_idx + win_size + 1):
+            if i < 0:
+                cur_audio_window.append(padding)
+            elif i >= len(audio_feat):
+                cur_audio_window.append(padding)
+            else:
+                cur_audio_window.append(audio_feat[i])
+        cur_audio_win_array = np.stack(cur_audio_window, axis=0)
+        audio_window_list.append(cur_audio_win_array)
+
+    audio_window_array = np.stack(audio_window_list, axis=0)
+    return audio_window_array
+
+
+def setup_config():
+    parser = argparse.ArgumentParser(description="voice2pose main program")
+    parser.add_argument(
+        "--config_file", default="", metavar="FILE", help="path to config file"
+    )
+    parser.add_argument(
+        "--resume_from", type=str, default=None, help="the checkpoint to resume from"
+    )
+    parser.add_argument(
+        "--test_only", action="store_true", help="perform testing and evaluation only"
+    )
+    parser.add_argument(
+        "--demo_input", type=str, default=None, help="path to input for demo"
+    )
+    parser.add_argument(
+        "--checkpoint", type=str, default=None, help="the checkpoint to test with"
+    )
+    parser.add_argument("--tag", type=str, default="", help="tag for the experiment")
+    parser.add_argument(
+        "opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        help="local rank for DistributedDataParallel",
+    )
+    parser.add_argument(
+        "--master_port",
+        type=str,
+        default="12345",
+    )
+    parser.add_argument(
+        "--max_audio_len",
+        type=int,
+        default=450,
+        help="max_audio_len for inference",
+    )
+    parser.add_argument(
+        "--ddim_num_step",
+        type=int,
+        default=10,
+    )
+    parser.add_argument(
+        "--inference_seed",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--inference_sample_method",
+        type=str,
+        default="ddim",
+    )
+    args = parser.parse_args()
+
+    cfg = get_cfg_defaults()
+    cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+    cfg.freeze()
+    return args, cfg
+
+
+def setup_logger(base_path, exp_name):
+    rootLogger = logging.getLogger()
+    rootLogger.setLevel(logging.INFO)
+
+    logFormatter = logging.Formatter("%(asctime)s [%(levelname)-0.5s] %(message)s")
+
+    log_path = "{0}/{1}.log".format(base_path, exp_name)
+    fileHandler = logging.FileHandler(log_path)
+    fileHandler.setFormatter(logFormatter)
+    rootLogger.addHandler(fileHandler)
+
+    consoleHandler = logging.StreamHandler()
+    consoleHandler.setFormatter(logFormatter)
+    rootLogger.addHandler(consoleHandler)
+    rootLogger.handlers[0].setLevel(logging.INFO)
+
+    logging.info("log path: %s" % log_path)
+
+
+def cosine_loss(a, v, y, logloss=nn.BCELoss()):
+    d = nn.functional.cosine_similarity(a, v)
+    loss = logloss(d.unsqueeze(1), y)
+    return loss
+
+
+def get_pose_params(mat_path):
+    """Get pose parameters from mat file
+
+    Args:
+        mat_path (str): path of mat file
+
+    Returns:
+        pose_params (numpy.ndarray): shape (L_video, 9), angle, translation, crop paramters
+    """
+    mat_dict = loadmat(mat_path)
+
+    np_3dmm = mat_dict["coeff"]
+    angles = np_3dmm[:, 224:227]
+    translations = np_3dmm[:, 254:257]
+
+    np_trans_params = mat_dict["transform_params"]
+    crop = np_trans_params[:, -3:]
+
+    pose_params = np.concatenate((angles, translations, crop), axis=1)
+
+    return pose_params
+
+
+def sinusoidal_embedding(timesteps, dim):
+    """
+
+    Args:
+        timesteps (_type_): (B,)
+        dim (_type_): (C_embed)
+
+    Returns:
+        _type_: (B, C_embed)
+    """
+    # check input
+    half = dim // 2
+    timesteps = timesteps.float()
+
+    # compute sinusoidal embedding
+    sinusoid = torch.outer(
+        timesteps, torch.pow(10000, -torch.arange(half).to(timesteps).div(half))
+    )
+    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+    if dim % 2 != 0:
+        x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
+    return x
+
+
+def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
+    """
+
+    Args:
+        audio_feat (np.ndarray): (250, 1024)
+        start_idx (_type_): _description_
+        num_frames (_type_): _description_
+    """
+    center_idx_list = [2 * idx for idx in range(start_idx, start_idx + num_frames)]
+    audio_window_list = []
+    padding = np.zeros(audio_feat.shape[1], dtype=np.float32)
+    for center_idx in center_idx_list:
+        cur_audio_window = []
+        for i in range(center_idx - win_size, center_idx + win_size + 1):
+            if i < 0:
+                cur_audio_window.append(padding)
+            elif i >= len(audio_feat):
+                cur_audio_window.append(padding)
+            else:
+                cur_audio_window.append(audio_feat[i])
+        cur_audio_win_array = np.stack(cur_audio_window, axis=0)
+        audio_window_list.append(cur_audio_win_array)
+
+    audio_window_array = np.stack(audio_window_list, axis=0)
+    return audio_window_array
+
+
+def reshape_audio_feat(style_audio_all_raw, stride):
+    """_summary_
+
+    Args:
+        style_audio_all_raw (_type_): (stride * L, C)
+        stride (_type_): int
+
+    Returns:
+        _type_: (L, C * stride)
+    """
+    style_audio_all_raw = style_audio_all_raw[
+        : style_audio_all_raw.shape[0] // stride * stride
+    ]
+    style_audio_all_raw = style_audio_all_raw.reshape(
+        style_audio_all_raw.shape[0] // stride, stride, style_audio_all_raw.shape[1]
+    )
+    style_audio_all = style_audio_all_raw.reshape(style_audio_all_raw.shape[0], -1)
+    return style_audio_all
+
+
+import random
+
+
+def get_derangement_tuple(n):
+    while True:
+        v = [i for i in range(n)]
+        for j in range(n - 1, -1, -1):
+            p = random.randint(0, j)
+            if v[p] == j:
+                break
+            else:
+                v[j], v[p] = v[p], v[j]
+        else:
+            if v[0] != 0:
+                return tuple(v)
+
+
+def compute_aspect_preserved_bbox(bbox, increase_area, h, w):
+    left, top, right, bot = bbox
+    width = right - left
+    height = bot - top
+
+    width_increase = max(
+        increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)
+    )
+    height_increase = max(
+        increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)
+    )
+
+    left_t = int(left - width_increase * width)
+    top_t = int(top - height_increase * height)
+    right_t = int(right + width_increase * width)
+    bot_t = int(bot + height_increase * height)
+
+    left_oob = -min(0, left_t)
+    right_oob = right - min(right_t, w)
+    top_oob = -min(0, top_t)
+    bot_oob = bot - min(bot_t, h)
+
+    if max(left_oob, right_oob, top_oob, bot_oob) > 0:
+        max_w = max(left_oob, right_oob)
+        max_h = max(top_oob, bot_oob)
+        if max_w > max_h:
+            return left_t + max_w, top_t + max_w, right_t - max_w, bot_t - max_w
+        else:
+            return left_t + max_h, top_t + max_h, right_t - max_h, bot_t - max_h
+
+    else:
+        return (left_t, top_t, right_t, bot_t)
+
+
+def crop_src_image(src_img, save_img, increase_ratio, detector=None):
+    if detector is None:
+        detector = dlib.get_frontal_face_detector()
+
+    img = cv2.imread(src_img)
+    faces = detector(img, 0)
+    h, width, _ = img.shape
+    if len(faces) > 0:
+        bbox = [faces[0].left(), faces[0].top(), faces[0].right(), faces[0].bottom()]
+        l = bbox[3] - bbox[1]
+        bbox[1] = bbox[1] - l * 0.1
+        bbox[3] = bbox[3] - l * 0.1
+        bbox[1] = max(0, bbox[1])
+        bbox[3] = min(h, bbox[3])
+        bbox = compute_aspect_preserved_bbox(
+            tuple(bbox), increase_ratio, img.shape[0], img.shape[1]
+        )
+        img = img[bbox[1] : bbox[3], bbox[0] : bbox[2]]
+        img = cv2.resize(img, (256, 256))
+        cv2.imwrite(save_img, img)
+    else:
+        raise ValueError("No face detected in the input image")
+        # img = cv2.resize(img, (256, 256))
+        # cv2.imwrite(save_img, img)

BIN
data/audio/German1.wav


BIN
data/audio/German2.wav


BIN
data/audio/German3.wav


BIN
data/audio/German4.wav


BIN
data/audio/acknowledgement_chinese.m4a


BIN
data/audio/acknowledgement_english.m4a


BIN
data/audio/chinese1_haierlizhi.wav


BIN
data/audio/chinese2_guanyu.wav


BIN
data/audio/french1.wav


BIN
data/audio/french2.wav


BIN
data/audio/french3.wav


BIN
data/audio/italian1.wav


BIN
data/audio/italian2.wav


BIN
data/audio/italian3.wav


BIN
data/audio/japan1.wav


BIN
data/audio/japan2.wav


BIN
data/audio/japan3.wav


BIN
data/audio/korean1.wav


BIN
data/audio/korean2.wav


BIN
data/audio/korean3.wav


BIN
data/audio/noisy_audio_cafeter_snr_0.wav


BIN
data/audio/noisy_audio_meeting_snr_0.wav


BIN
data/audio/noisy_audio_meeting_snr_10.wav


BIN
data/audio/noisy_audio_meeting_snr_20.wav


BIN
data/audio/noisy_audio_narrative.wav


BIN
data/audio/noisy_audio_office_snr_0.wav


BIN
data/audio/out_of_domain_narrative.wav


BIN
data/audio/spanish1.wav


BIN
data/audio/spanish2.wav


BIN
data/audio/spanish3.wav


BIN
data/pose/RichardShelby_front_neutral_level1_001.mat


BIN
data/src_img/cropped/chpa5.png


BIN
data/src_img/cropped/cut_img.png


BIN
data/src_img/cropped/f30.png


BIN
data/src_img/cropped/menglu2.png


BIN
data/src_img/cropped/nscu2.png


BIN
data/src_img/cropped/zp1.png


BIN
data/src_img/cropped/zt12.png


BIN
data/src_img/uncropped/face3.png


BIN
data/src_img/uncropped/male_face.png


BIN
data/src_img/uncropped/uncut_src_img.jpg


BIN
data/style_clip/3DMM/M030_front_angry_level3_001.mat


BIN
data/style_clip/3DMM/M030_front_contempt_level3_001.mat


BIN
data/style_clip/3DMM/M030_front_disgusted_level3_001.mat


BIN
data/style_clip/3DMM/M030_front_fear_level3_001.mat


BIN
data/style_clip/3DMM/M030_front_happy_level3_001.mat


BIN
data/style_clip/3DMM/M030_front_neutral_level1_001.mat


BIN
data/style_clip/3DMM/M030_front_sad_level3_001.mat


BIN
data/style_clip/3DMM/M030_front_surprised_level3_001.mat


BIN
data/style_clip/3DMM/W009_front_angry_level3_001.mat


BIN
data/style_clip/3DMM/W009_front_contempt_level3_001.mat


BIN
data/style_clip/3DMM/W009_front_disgusted_level3_001.mat


BIN
data/style_clip/3DMM/W009_front_fear_level3_001.mat


BIN
data/style_clip/3DMM/W009_front_happy_level3_001.mat


BIN
data/style_clip/3DMM/W009_front_neutral_level1_001.mat


BIN
data/style_clip/3DMM/W009_front_sad_level3_001.mat


BIN
data/style_clip/3DMM/W009_front_surprised_level3_001.mat


BIN
data/style_clip/3DMM/W011_front_angry_level3_001.mat


BIN
data/style_clip/3DMM/W011_front_contempt_level3_001.mat


BIN
data/style_clip/3DMM/W011_front_disgusted_level3_001.mat


BIN
data/style_clip/3DMM/W011_front_fear_level3_001.mat


BIN
data/style_clip/3DMM/W011_front_happy_level3_001.mat


BIN
data/style_clip/3DMM/W011_front_neutral_level1_001.mat


BIN
data/style_clip/3DMM/W011_front_sad_level3_001.mat


BIN
data/style_clip/3DMM/W011_front_surprised_level3_001.mat


BIN
data/style_clip/video/M030_front_angry_level3_001.mp4


BIN
data/style_clip/video/M030_front_contempt_level3_001.mp4


BIN
data/style_clip/video/M030_front_disgusted_level3_001.mp4


BIN
data/style_clip/video/M030_front_fear_level3_001.mp4


BIN
data/style_clip/video/M030_front_happy_level3_001.mp4


BIN
data/style_clip/video/M030_front_neutral_level1_001.mp4


BIN
data/style_clip/video/M030_front_sad_level3_001.mp4


BIN
data/style_clip/video/M030_front_surprised_level3_001.mp4


BIN
data/style_clip/video/W009_front_angry_level3_001.mp4


BIN
data/style_clip/video/W009_front_contempt_level3_001.mp4


BIN
data/style_clip/video/W009_front_disgusted_level3_001.mp4


BIN
data/style_clip/video/W009_front_fear_level3_001.mp4


BIN
data/style_clip/video/W009_front_happy_level3_001.mp4


BIN
data/style_clip/video/W009_front_neutral_level1_001.mp4


BIN
data/style_clip/video/W009_front_sad_level3_001.mp4


BIN
data/style_clip/video/W009_front_surprised_level3_001.mp4


BIN
data/style_clip/video/W011_front_angry_level3_001.mp4


BIN
data/style_clip/video/W011_front_contempt_level3_001.mp4


BIN
data/style_clip/video/W011_front_disgusted_level3_001.mp4


BIN
data/style_clip/video/W011_front_fear_level3_001.mp4


Some files were not shown because too many files changed in this diff