ESC
输入关键词搜索文章
目录

数字人工程解读(五):Ditto 源码:把 Talking Head 做成实时流水线

Context
这个仓库解决什么问题

antgroup/ditto-talkinghead 是 Ant Group 开源的实时 Talking Head 生成项目,对应论文 Ditto: Motion-Space Diffusion for Controllable Realtime Talking Head Synthesis。仓库主分支面向推理,训练代码放在 train 分支。#Ditto GitHub #Ditto Paper

一句话判断:Ditto 不是一个单纯的 lipsync demo,而是把人脸预处理、音频特征、运动扩散、运动缝合、3D feature warping、图像解码和视频写出组织成一条可流式运行的工程流水线。

它的输入是一张人脸图或一段 source video,再加一段音频;输出是一段跟随音频说话的头像视频。README 给出的推理命令显示,用户主要通过 inference.py 传入 --data_root--cfg_pkl--audio_path--source_path--output_path

Ditto example source image
图 1:仓库 example/image.png 中的示例 source image。推理时它会先经过人脸检测、裁剪、关键点与 appearance feature 提取,再进入后续驱动流程。
Usage
先看怎么安装、推理和训练

读模型仓库时,使用路径会暴露真实工程边界:依赖什么硬件、权重如何组织、入口参数是什么、训练和推理是否在同一个分支。

安装与模型下载

README 推荐用 environment.yaml 创建 conda 环境;如果不用 conda,则要自行准备 PyTorch、CUDA、cuDNN、TensorRT、librosa、OpenCV、cuda-python 和 ffmpeg 等依赖。README 标注测试环境为 CentOS 7.2、NVIDIA A100、Python 3.10、PyTorch 2.5.1、CUDA 12.1、TensorRT 8.6.1。#Ditto GitHub

conda env create -f environment.yaml
conda activate ditto

权重从 HuggingFace 下载到本地 checkpoints 目录。README 中列出的权重形态包括 ditto_cfgditto_onnxditto_trt_Ampere_Plusditto_pytorch。本次阅读中 HuggingFace 页面访问超时,因此只采用 README 中能核实到的说明。#Ditto HuggingFace

推理命令

README 的推理命令本质上是在告诉 inference.py 三件事:模型在哪里、输入素材在哪里、结果写到哪里。

python inference.py \
  --data_root "./checkpoints/ditto_trt_Ampere_Plus" \
  --cfg_pkl "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" \
  --audio_path "./example/audio.wav" \
  --source_path "./example/image.png" \
  --output_path "./tmp/result.mp4"
参数作用怎么替换
--data_root模型文件根目录。parse_cfg 会把配置里的模型相对路径拼到这个目录下,因此它决定本次使用 TensorRT、ONNX 还是 PyTorch 权重。想用 TensorRT 就指向 ditto_trt_Ampere_Plus 或自己转换出的 TensorRT 目录;想用 PyTorch 就指向 ditto_pytorch
--cfg_pkl推理配置文件,里面记录各 atomic component 使用哪些模型、输入输出维度和运行参数。必须和 data_root 的模型格式匹配。例如 TensorRT 用 v0.4_hubert_cfg_trt.pkl,PyTorch 用 v0.4_hubert_cfg_pytorch.pkl
--audio_path驱动音频。源码会用 librosa 以 16 kHz 读取,再交给 Wav2Feat 提取 HuBERT 特征。换成自己的 .wav 音频路径。音频内容决定嘴型、表情节奏和说话时长。
--source_path被驱动的人脸素材,可以是 source image,也可以按仓库支持传 source video。换成自己的头像图片或视频。这个素材提供身份、初始姿态、appearance feature 和 source keypoints。
--output_path最终输出视频路径。pipeline 会先写临时视频,再用 ffmpeg 合成原音频得到这个 mp4。换成你想保存的路径,例如 ./tmp/my-result.mp4;目录要可写。
如何判断命令是否配对正确

--data_root--cfg_pkl 要作为一组看:TensorRT 目录配 TensorRT 配置,PyTorch 目录配 PyTorch 配置。如果 GPU 架构不兼容仓库提供的 TensorRT engine,就先用 scripts/cvt_onnx_to_trt.py 从 ONNX 转成本机 engine,再把 --data_root 指向转换后的目录。

训练在 train 分支

主分支没有完整训练入口。训练代码位于 train 分支,其中入口是 MotionDiT/train.py,数据准备脚本位于 prepare_data,Quick Start 使用 accelerate launch train.py 启动。训练分支 README 提醒:示例训练只用于走通流程,要得到合理效果需要更多高质量数据和更长训练。

git checkout train
cd MotionDiT
accelerate launch train.py \
  --experiment_dir ${EXP_DIR} \
  --experiment_name ${EXP_NAME} \
  --use_sc \
  --use_last_frame \
  --use_last_frame_loss \
  --use_emo \
  --use_eye_open \
  --use_eye_ball \
  --audio_feat_dim 1103 \
  --motion_feat_dim 265
参数作用阅读代码时对应哪里
--experiment_dir实验输出根目录,用来保存日志、配置快照和 checkpoint。Trainer._init_log 与 checkpoint 保存逻辑。
--experiment_name本次实验名称,通常作为输出目录的一部分。训练日志目录和 checkpoint 目录组织。
--use_sc启用 source canonical / source 相关条件,让 motion 生成感知 source 身份与结构。Stage2Dataset 组织 aud_cond 和附加条件。
--use_last_frame把上一帧或条件帧作为生成条件,帮助片段之间保持连续。kp_condcond_frame 进入 LMDM。
--use_last_frame_loss增加 last-frame 相关损失,约束片段边界或条件帧附近的连续性。MotionDiffusion.p_losses 中的 last-frame loss 分支。
--use_emo把 emotion 特征加入条件,影响表情生成。Stage2Dataset 加载 emo 特征并拼入条件。
--use_eye_open / --use_eye_ball加入眼睛开合和眼球方向条件,帮助控制眨眼与视线。训练数据中的 eye 特征,以及推理侧 MotionStitch 的眼睛控制逻辑。
--audio_feat_dim音频条件维度。README 示例为 1103,需要和预处理出来的音频特征一致。MotionDecoder 的条件投影层输入维度。
--motion_feat_dimmotion 表示维度。README 示例为 265,需要和数据集导出的 motion feature 一致。LMDM / MotionDecoder 的 motion 输入输出维度。
Architecture
从运行单元看整体架构

主分支最重要的不是单个模型定义,而是 inference.pystream_pipeline_offline.py / stream_pipeline_online.pycore/atomic_components 下的一组原子组件。它们共同把“音频驱动头像”拆成多阶段队列流水线。

flowchart TD
  User["用户输入
source image/video + audio"] --> CLI["inference.py
命令行入口"] CLI --> SDK["StreamSDK
推理编排"] SDK --> CFG["parse_cfg
cfg_pkl + data_root"] CFG --> Loader["load_model
ONNX / TensorRT / PyTorch"] SDK --> Source["Source2Info / AvatarRegistrar
人脸裁剪、关键点、appearance feature"] SDK --> Audio["Wav2Feat
HuBERT 音频特征"] Source --> Cond["ConditionHandler
source canonical + audio + emotion + eye"] Audio --> Cond Cond --> A2M["Audio2Motion
调用 LMDM"] A2M --> Stitch["MotionStitch
运动缝合与控制"] Stitch --> Warp["WarpF3D
3D feature warping"] Warp --> Decode["DecodeF3D
图像解码"] Decode --> Putback["PutBack
回贴到原图"] Putback --> Writer["VideoWriter + ffmpeg
输出 mp4"]
层级仓库证据职责
入口inference.py解析参数、初始化 SDK、读取音频、触发推理
流水线stream_pipeline_offline.py / stream_pipeline_online.py创建队列与 worker,把各阶段串成可流式执行的 pipeline
模型加载core/utils/load_model.py按后缀统一加载 ONNX、TensorRT、PyTorch 模型
运动扩散core/atomic_components/audio2motion.pycore/models/lmdm.py把音频和条件转为 motion sequence
运动控制core/atomic_components/motion_stitch.py融合 source motion、driving motion、眼睛和姿态控制
Code Path
推理主调用链怎么跑起来

inference.py 是一个很薄的入口:它解析命令行参数,创建 StreamSDK,调用 run。真正的工程复杂度在 StreamSDK.setup 和各个 worker 中。

inference.py
  -> StreamSDK(cfg_pkl, data_root)
  -> SDK.setup(source_path, output_path)
  -> librosa.load(audio_path, sr=16000)
  -> Wav2Feat.wav2feat(audio)
  -> audio2motion_queue.put(audio_feature)
  -> 多线程 worker pipeline
  -> ffmpeg 合成临时视频与原音频

入口源码:run 如何把音频送进流水线

这段是真实入口代码。它先用 SDK.setup 注册 source,再把音频固定重采样到 16 kHz;离线模式下,HuBERT 特征会被放入 audio2motion_queue,后续 worker 从队列继续消费。

inference.py
run

def run(SDK: StreamSDK, audio_path: str, source_path: str, output_path: str, more_kwargs: str | dict = {}):

    if isinstance(more_kwargs, str):
        more_kwargs = load_pkl(more_kwargs)
    setup_kwargs = more_kwargs.get("setup_kwargs", {})
    run_kwargs = more_kwargs.get("run_kwargs", {})

    SDK.setup(source_path, output_path, **setup_kwargs)

    audio, sr = librosa.core.load(audio_path, sr=16000)
    num_f = math.ceil(len(audio) / 16000 * 25)

    fade_in = run_kwargs.get("fade_in", -1)
    fade_out = run_kwargs.get("fade_out", -1)
    ctrl_info = run_kwargs.get("ctrl_info", {})
    SDK.setup_Nd(N_d=num_f, fade_in=fade_in, fade_out=fade_out, ctrl_info=ctrl_info)

    online_mode = SDK.online_mode
    if online_mode:
        chunksize = run_kwargs.get("chunksize", (3, 5, 2))
        audio = np.concatenate([np.zeros((chunksize[0] * 640,), dtype=np.float32), audio], 0)
        split_len = int(sum(chunksize) * 0.04 * 16000) + 80  # 6480
        for i in range(0, len(audio), chunksize[1] * 640):
            audio_chunk = audio[i:i + split_len]
            if len(audio_chunk) < split_len:
                audio_chunk = np.pad(audio_chunk, (0, split_len - len(audio_chunk)), mode="constant")
            SDK.run_chunk(audio_chunk, chunksize)
    else:
        aud_feat = SDK.wav2feat.wav2feat(audio)
        SDK.audio2motion_queue.put(aud_feat)
    SDK.close()

    cmd = f'ffmpeg -loglevel error -y -i "{SDK.tmp_output_path}" -i "{audio_path}" -map 0:v -map 1:a -c:v copy -c:a aac "{output_path}"'
    print(cmd)
    os.system(cmd)
sequenceDiagram
  participant CLI as inference.py
  participant SDK as StreamSDK
  participant A2M as Audio2Motion worker
  participant Stitch as MotionStitch worker
  participant Render as Warp/Decode/PutBack workers
  participant Writer as Writer/ffmpeg
  CLI->>SDK: "setup(source_path, output_path)"
  CLI->>SDK: "run(audio_path)"
  SDK->>A2M: "audio feature queue"
  A2M->>Stitch: "motion sequence"
  Stitch->>Render: "stitched keypoints"
  Render->>Writer: "frame images"
  Writer-->>CLI: "result.mp4"
Algorithm: Ditto offline inference
Input: source_path, audio_path, cfg_pkl, data_root
1. Parse cfg_pkl and bind every model path to data_root.
2. Register source image/video: crop face, extract keypoints and 3D feature.
3. Convert waveform to HuBERT audio feature.
4. Feed audio feature into audio2motion queue.
5. LMDM samples motion sequence in motion space.
6. MotionStitch mixes source motion, driving motion, gaze and pose controls.
7. Warp 3D feature, decode crop image, put crop back to original canvas.
8. Write temporary video and mux original audio by ffmpeg.
Output: result.mp4
Motion Diffusion
LMDM:扩散模型落在 motion space

Ditto 的关键不是直接在像素空间扩散,而是在 motion 表示上做 diffusion。推理侧的核心路径是 core/atomic_components/audio2motion.py 调用 core/models/lmdm.py,再进入 core/models/modules/LMDM.pycore/models/modules/lmdm_modules/model.py

flowchart LR
  X["noisy motion x_t"] --> Decoder["MotionDecoder"]
  Frame["cond_frame"] --> Decoder
  Audio["audio condition"] --> Tokens["condition tokens"]
  EmoEye["emotion / eye / source canonical"] --> Tokens
  T["timestep embedding"] --> Decoder
  Tokens --> Decoder
  Decoder --> X0["predicted x_start"]
  X0 --> DDIM["DDIM update"]
  DDIM --> Next["x_t_minus_1"]

MotionDecoder 会把当前 noisy motion、条件帧、音频条件以及 timestep embedding 组合起来。源码中的 guided_forward 是 classifier-free guidance 风格:先算无条件输出,再算有条件输出,然后用 guidance_weight 放大二者差值。

核心源码:guidance 与 DDIM 采样

guided_forward 不是伪概念,而是源码里直接实现的 classifier-free guidance:同一批 noisy motion 分别跑无条件和有条件分支,再用 guidance_weight 放大条件分支带来的差异。

core/models/modules/lmdm_modules/model.py
MotionDecoder.guided_forward / forward

    def guided_forward(self, x, cond_frame, cond_embed, times, guidance_weight):
        unc = self.forward(x, cond_frame, cond_embed, times, cond_drop_prob=1)
        conditioned = self.forward(x, cond_frame, cond_embed, times, cond_drop_prob=0)

        return unc + (conditioned - unc) * guidance_weight

    def forward(
        self, x: Tensor, cond_frame: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0
    ):
        batch_size, device = x.shape[0], x.device

        # concat last frame, project to latent space
        # cond_frame: [b, dim] | [b, n, dim+1]
        if self.multi_cond_frame:
            # [b, n, dim+1] (+1 mask)
            x = torch.cat([x, cond_frame], dim=-1)
        else:
            # [b, dim]
            x = torch.cat([x, cond_frame.unsqueeze(1).repeat(1, x.shape[1], 1)], dim=-1)
        x = self.input_projection(x)
        # add the positional embeddings of the input sequence to provide temporal information
        x = self.abs_pos_encoding(x)

        # create audio conditional embedding with conditional dropout
        keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
        keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
        keep_mask_hidden = rearrange(keep_mask, "b -> b 1")

        cond_tokens = self.cond_projection(cond_embed)
        # encode tokens
        cond_tokens = self.abs_pos_encoding(cond_tokens)
        cond_tokens = self.cond_encoder(cond_tokens)

        null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
        cond_tokens = torch.where(keep_mask_embed, cond_tokens, null_cond_embed)

ddim_sample 则负责从高斯噪声开始迭代更新 motion。每一步先预测 x_start 与噪声,再按 DDIM 公式更新到下一个时间步。

core/models/modules/LMDM.py
MotionDiffusion.ddim_sample

    @torch.no_grad()
    def ddim_sample(self, kp_cond, aud_cond, sampling_timesteps):
        self.setup(sampling_timesteps)

        cond_frame = kp_cond
        cond = aud_cond

        shape = (1, self.seq_frames, self.motion_feat_dim)
        x = torch.randn(shape, device=self.device)

        x_start = None
        i = 0
        for _, time_next in self.time_pairs:
            time_cond = self.time_cond_list[i]
            pred_noise, x_start = self.model_predictions(x, cond_frame, cond, time_cond)
            if time_next < 0:
                x = x_start
                continue

            alpha_next_sqrt = self.alpha_next_sqrt_list[i]
            c = self.c_list[i]
            sigma = self.sigma_list[i]
            noise = self.noise_list[i]
            x = x_start * alpha_next_sqrt + c * pred_noise + sigma * noise

            i += 1
        return x  # pred_kp_seq
设计直觉:把扩散放在 motion space,可以把生成问题从“每帧像素怎么画”转成“表情、姿态、关键点和眼睛怎么随音频演化”。
Training
训练分支如何组织数据和 loss

训练分支的主入口是 MotionDiT/train.py。它通过 tyro 解析训练参数,创建 Trainer,再由 Trainer.train_loop 执行训练。关键路径包括 MotionDiT/src/trainers/trainer.pyMotionDiT/src/datasets/s2_dataset_v2.pyMotionDiT/src/models/LMDM.pyMotionDiT/src/models/modules/diffusion.py

flowchart TD
  TrainPy["MotionDiT/train.py"] --> Trainer["Trainer"]
  Trainer --> Dataset["Stage2Dataset"]
  Dataset --> Batch["kp_seq / kp_cond / aud_cond"]
  Trainer --> LMDMTrain["LMDM.diffusion"]
  Batch --> LMDMTrain
  LMDMTrain --> Loss["MotionDiffusion.p_losses"]
  Loss --> Backward["loss backward + optimizer step"]

Stage2Dataset 会加载 motion、audio、emotion、eye 等预处理特征,切成固定长度片段,返回 kp_seqkp_condaud_condp_losses 不只计算 position loss,还支持 velocity、acceleration、last-frame loss 和 regularization loss。

Engineering Notes
工程取舍与复现边界

core/utils/load_model.py 通过文件后缀选择 ONNX Runtime、TensorRT 或 PyTorch backend。这让上层 atomic component 不必关心底层模型格式,也解释了为什么 README 可以提供 TensorRT、ONNX、PyTorch 三类权重路径。

模型加载源码:用后缀切换推理后端

这里是可替换 backend 的关键:.onnx 走 ONNX Runtime,.engine / .trt 走 TensorRT,.pt / .pth 才实例化 PyTorch module。

core/utils/load_model.py
load_model

def load_model(model_path: str, device: str = "cuda", **kwargs):
    if kwargs.get("force_ori_type", False):
        # for hubert, landmark, retinaface, mediapipe
        model = load_force_ori_type(model_path, device, **kwargs)
        return model, "ori"

    if model_path.endswith(".onnx"):
        # onnx
        import onnxruntime

        if device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        model = onnxruntime.InferenceSession(model_path, providers=providers)
        return model, "onnx"

    elif model_path.endswith(".engine") or model_path.endswith(".trt"):
        # tensorRT
        from .tensorrt_utils import TRTWrapper

        model = TRTWrapper(model_path)
        return model, "tensorrt"

    elif model_path.endswith(".pt") or model_path.endswith(".pth"):
        # pytorch
        model = create_model(model_path, device, **kwargs)
        return model, "pytorch"

运动缝合源码:source 身份与 driving motion 如何混合

MotionStitch 是控制感最强的一段代码:它先按 use_d_keys 混合 source 与 driving motion,再处理眼睛、VAD、姿态偏移和 fade,最后把 source / driving motion 都变成 keypoint,并可选调用 stitching network 修正。

core/atomic_components/motion_stitch.py
MotionStitch.__call__

    def __call__(self, x_s_info, x_d_info, **kwargs):
        # return x_s, x_d

        kwargs = self._merge_kwargs(self.overall_ctrl_info, kwargs)

        if self.scale_ratio is None:
            self.scale_b = x_s_info['scale'].item()
            self.scale_ratio = self.scale_a / self.scale_b
            self._set_scale_ratio(self.scale_ratio)

        if self.relative_d and self.d0 is None:
            self.d0 = copy.deepcopy(x_d_info)

        x_d_info = _mix_s_d_info(
            x_s_info,
            x_d_info,
            self.use_d_keys,
            self.d0,
        )

        delta_eye = 0
        if self.drive_eye and self.delta_eye_arr is not None:
            delta_eye = self.delta_eye_arr[
                self.delta_eye_idx_list[self.idx % len(self.delta_eye_idx_list)]
            ][None]
        x_d_info = _fix_exp_for_x_d_info_v2(
            x_d_info,
            x_s_info,
            delta_eye,
            self.fix_exp_a1,
            self.fix_exp_a2,
            self.fix_exp_a3,
        )

        if kwargs.get("vad_alpha", 1) < 1:
            x_d_info = ctrl_vad(x_d_info, x_s_info, kwargs.get("vad_alpha", 1))

        x_d_info = ctrl_motion(x_d_info, **kwargs)

        if self.fade_type == "d0" and self.fade_dst is None:
            self.fade_dst = copy.deepcopy(x_d_info)

        # fade
        if "fade_alpha" in kwargs and self.fade_type in ["d0", "s"]:
            fade_alpha = kwargs["fade_alpha"]
            fade_keys = kwargs.get("fade_out_keys", self.fade_out_keys)
            if self.fade_type == "d0":
                fade_dst = self.fade_dst
            elif self.fade_type == "s":
                if self.fade_dst is not None:
                    fade_dst = self.fade_dst
                else:
                    fade_dst = copy.deepcopy(x_s_info)
                    if self.is_image_flag:
                        self.fade_dst = fade_dst
            x_d_info = fade(x_d_info, fade_dst, fade_alpha, fade_keys)

        if self.drive_eye:
            if self.pose_s is None:
                yaw_s = bin66_to_degree(x_s_info['yaw']).item()
                pitch_s = bin66_to_degree(x_s_info['pitch']).item()
                self.pose_s = [yaw_s, pitch_s]
            x_d_info = _fix_gaze(self.pose_s, x_d_info)

        if self.x_s is not None:
            x_s = self.x_s
        else:
            x_s = transform_keypoint(x_s_info)
            if self.is_image_flag:
                self.x_s = x_s
        
        x_d = transform_keypoint(x_d_info)
        
        if self.flag_stitching:
            x_d = self.stitch_net(x_s, x_d)

从源码结构看,实时性并不是单靠一个模型很快,而是来自“分阶段队列 + worker 并行 + 可替换推理后端”的组合。

复现边界

本机没有 checkpoints、没有可用 nvidia-smi、也没有 conda,因此这次没有实际跑通推理命令。本文对运行方式的说明来自 README 和源码;对输出效果不做本机实测背书。

Takeaways
读完这个仓库能带走什么

复习速查

  • 项目定位:Ditto 是一个实时可控 Talking Head 系统,主分支负责推理,训练在 train 分支。
  • 核心路径inference.py -> StreamSDK -> Audio2Motion/LMDM -> MotionStitch -> WarpF3D -> DecodeF3D -> PutBack -> Writer
  • 关键设计:扩散模型不直接生成像素,而是在 motion space 生成运动,再通过 feature warping 和 decoder 变成视频帧。
  • 工程价值:这个仓库展示了研究模型如何落成实时 pipeline:后端可替换、阶段可并行、控制逻辑独立于生成模型。
  • 复现提醒:TensorRT、CUDA、GPU 架构和外部 checkpoint 是主要门槛;训练示例只是流程验证,不代表最终效果。

参考来源