数字人工程解读(五):Ditto 源码:把 Talking Head 做成实时流水线
antgroup/ditto-talkinghead 是 Ant Group 开源的实时 Talking Head 生成项目,对应论文 Ditto: Motion-Space Diffusion for Controllable Realtime Talking Head Synthesis。仓库主分支面向推理,训练代码放在 train 分支。#Ditto GitHub #Ditto Paper
它的输入是一张人脸图或一段 source video,再加一段音频;输出是一段跟随音频说话的头像视频。README 给出的推理命令显示,用户主要通过 inference.py 传入 --data_root、--cfg_pkl、--audio_path、--source_path 和 --output_path。

example/image.png 中的示例 source image。推理时它会先经过人脸检测、裁剪、关键点与 appearance feature 提取,再进入后续驱动流程。读模型仓库时,使用路径会暴露真实工程边界:依赖什么硬件、权重如何组织、入口参数是什么、训练和推理是否在同一个分支。
安装与模型下载
README 推荐用 environment.yaml 创建 conda 环境;如果不用 conda,则要自行准备 PyTorch、CUDA、cuDNN、TensorRT、librosa、OpenCV、cuda-python 和 ffmpeg 等依赖。README 标注测试环境为 CentOS 7.2、NVIDIA A100、Python 3.10、PyTorch 2.5.1、CUDA 12.1、TensorRT 8.6.1。#Ditto GitHub
conda env create -f environment.yaml
conda activate ditto权重从 HuggingFace 下载到本地 checkpoints 目录。README 中列出的权重形态包括 ditto_cfg、ditto_onnx、ditto_trt_Ampere_Plus 和 ditto_pytorch。本次阅读中 HuggingFace 页面访问超时,因此只采用 README 中能核实到的说明。#Ditto HuggingFace
推理命令
README 的推理命令本质上是在告诉 inference.py 三件事:模型在哪里、输入素材在哪里、结果写到哪里。
python inference.py \
--data_root "./checkpoints/ditto_trt_Ampere_Plus" \
--cfg_pkl "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" \
--audio_path "./example/audio.wav" \
--source_path "./example/image.png" \
--output_path "./tmp/result.mp4"| 参数 | 作用 | 怎么替换 |
|---|---|---|
--data_root | 模型文件根目录。parse_cfg 会把配置里的模型相对路径拼到这个目录下,因此它决定本次使用 TensorRT、ONNX 还是 PyTorch 权重。 | 想用 TensorRT 就指向 ditto_trt_Ampere_Plus 或自己转换出的 TensorRT 目录;想用 PyTorch 就指向 ditto_pytorch。 |
--cfg_pkl | 推理配置文件,里面记录各 atomic component 使用哪些模型、输入输出维度和运行参数。 | 必须和 data_root 的模型格式匹配。例如 TensorRT 用 v0.4_hubert_cfg_trt.pkl,PyTorch 用 v0.4_hubert_cfg_pytorch.pkl。 |
--audio_path | 驱动音频。源码会用 librosa 以 16 kHz 读取,再交给 Wav2Feat 提取 HuBERT 特征。 | 换成自己的 .wav 音频路径。音频内容决定嘴型、表情节奏和说话时长。 |
--source_path | 被驱动的人脸素材,可以是 source image,也可以按仓库支持传 source video。 | 换成自己的头像图片或视频。这个素材提供身份、初始姿态、appearance feature 和 source keypoints。 |
--output_path | 最终输出视频路径。pipeline 会先写临时视频,再用 ffmpeg 合成原音频得到这个 mp4。 | 换成你想保存的路径,例如 ./tmp/my-result.mp4;目录要可写。 |
--data_root 和 --cfg_pkl 要作为一组看:TensorRT 目录配 TensorRT 配置,PyTorch 目录配 PyTorch 配置。如果 GPU 架构不兼容仓库提供的 TensorRT engine,就先用 scripts/cvt_onnx_to_trt.py 从 ONNX 转成本机 engine,再把 --data_root 指向转换后的目录。
训练在 train 分支
主分支没有完整训练入口。训练代码位于 train 分支,其中入口是 MotionDiT/train.py,数据准备脚本位于 prepare_data,Quick Start 使用 accelerate launch train.py 启动。训练分支 README 提醒:示例训练只用于走通流程,要得到合理效果需要更多高质量数据和更长训练。
git checkout train
cd MotionDiT
accelerate launch train.py \
--experiment_dir ${EXP_DIR} \
--experiment_name ${EXP_NAME} \
--use_sc \
--use_last_frame \
--use_last_frame_loss \
--use_emo \
--use_eye_open \
--use_eye_ball \
--audio_feat_dim 1103 \
--motion_feat_dim 265| 参数 | 作用 | 阅读代码时对应哪里 |
|---|---|---|
--experiment_dir | 实验输出根目录,用来保存日志、配置快照和 checkpoint。 | Trainer._init_log 与 checkpoint 保存逻辑。 |
--experiment_name | 本次实验名称,通常作为输出目录的一部分。 | 训练日志目录和 checkpoint 目录组织。 |
--use_sc | 启用 source canonical / source 相关条件,让 motion 生成感知 source 身份与结构。 | Stage2Dataset 组织 aud_cond 和附加条件。 |
--use_last_frame | 把上一帧或条件帧作为生成条件,帮助片段之间保持连续。 | kp_cond、cond_frame 进入 LMDM。 |
--use_last_frame_loss | 增加 last-frame 相关损失,约束片段边界或条件帧附近的连续性。 | MotionDiffusion.p_losses 中的 last-frame loss 分支。 |
--use_emo | 把 emotion 特征加入条件,影响表情生成。 | Stage2Dataset 加载 emo 特征并拼入条件。 |
--use_eye_open / --use_eye_ball | 加入眼睛开合和眼球方向条件,帮助控制眨眼与视线。 | 训练数据中的 eye 特征,以及推理侧 MotionStitch 的眼睛控制逻辑。 |
--audio_feat_dim | 音频条件维度。README 示例为 1103,需要和预处理出来的音频特征一致。 | MotionDecoder 的条件投影层输入维度。 |
--motion_feat_dim | motion 表示维度。README 示例为 265,需要和数据集导出的 motion feature 一致。 | LMDM / MotionDecoder 的 motion 输入输出维度。 |
主分支最重要的不是单个模型定义,而是 inference.py、stream_pipeline_offline.py / stream_pipeline_online.py 和 core/atomic_components 下的一组原子组件。它们共同把“音频驱动头像”拆成多阶段队列流水线。
flowchart TD User["用户输入
source image/video + audio"] --> CLI["inference.py
命令行入口"] CLI --> SDK["StreamSDK
推理编排"] SDK --> CFG["parse_cfg
cfg_pkl + data_root"] CFG --> Loader["load_model
ONNX / TensorRT / PyTorch"] SDK --> Source["Source2Info / AvatarRegistrar
人脸裁剪、关键点、appearance feature"] SDK --> Audio["Wav2Feat
HuBERT 音频特征"] Source --> Cond["ConditionHandler
source canonical + audio + emotion + eye"] Audio --> Cond Cond --> A2M["Audio2Motion
调用 LMDM"] A2M --> Stitch["MotionStitch
运动缝合与控制"] Stitch --> Warp["WarpF3D
3D feature warping"] Warp --> Decode["DecodeF3D
图像解码"] Decode --> Putback["PutBack
回贴到原图"] Putback --> Writer["VideoWriter + ffmpeg
输出 mp4"]
| 层级 | 仓库证据 | 职责 |
|---|---|---|
| 入口 | inference.py | 解析参数、初始化 SDK、读取音频、触发推理 |
| 流水线 | stream_pipeline_offline.py / stream_pipeline_online.py | 创建队列与 worker,把各阶段串成可流式执行的 pipeline |
| 模型加载 | core/utils/load_model.py | 按后缀统一加载 ONNX、TensorRT、PyTorch 模型 |
| 运动扩散 | core/atomic_components/audio2motion.py、core/models/lmdm.py | 把音频和条件转为 motion sequence |
| 运动控制 | core/atomic_components/motion_stitch.py | 融合 source motion、driving motion、眼睛和姿态控制 |
inference.py 是一个很薄的入口:它解析命令行参数,创建 StreamSDK,调用 run。真正的工程复杂度在 StreamSDK.setup 和各个 worker 中。
inference.py
-> StreamSDK(cfg_pkl, data_root)
-> SDK.setup(source_path, output_path)
-> librosa.load(audio_path, sr=16000)
-> Wav2Feat.wav2feat(audio)
-> audio2motion_queue.put(audio_feature)
-> 多线程 worker pipeline
-> ffmpeg 合成临时视频与原音频入口源码:run 如何把音频送进流水线
这段是真实入口代码。它先用 SDK.setup 注册 source,再把音频固定重采样到 16 kHz;离线模式下,HuBERT 特征会被放入 audio2motion_queue,后续 worker 从队列继续消费。
inference.py- run
def run(SDK: StreamSDK, audio_path: str, source_path: str, output_path: str, more_kwargs: str | dict = {}):
if isinstance(more_kwargs, str):
more_kwargs = load_pkl(more_kwargs)
setup_kwargs = more_kwargs.get("setup_kwargs", {})
run_kwargs = more_kwargs.get("run_kwargs", {})
SDK.setup(source_path, output_path, **setup_kwargs)
audio, sr = librosa.core.load(audio_path, sr=16000)
num_f = math.ceil(len(audio) / 16000 * 25)
fade_in = run_kwargs.get("fade_in", -1)
fade_out = run_kwargs.get("fade_out", -1)
ctrl_info = run_kwargs.get("ctrl_info", {})
SDK.setup_Nd(N_d=num_f, fade_in=fade_in, fade_out=fade_out, ctrl_info=ctrl_info)
online_mode = SDK.online_mode
if online_mode:
chunksize = run_kwargs.get("chunksize", (3, 5, 2))
audio = np.concatenate([np.zeros((chunksize[0] * 640,), dtype=np.float32), audio], 0)
split_len = int(sum(chunksize) * 0.04 * 16000) + 80 # 6480
for i in range(0, len(audio), chunksize[1] * 640):
audio_chunk = audio[i:i + split_len]
if len(audio_chunk) < split_len:
audio_chunk = np.pad(audio_chunk, (0, split_len - len(audio_chunk)), mode="constant")
SDK.run_chunk(audio_chunk, chunksize)
else:
aud_feat = SDK.wav2feat.wav2feat(audio)
SDK.audio2motion_queue.put(aud_feat)
SDK.close()
cmd = f'ffmpeg -loglevel error -y -i "{SDK.tmp_output_path}" -i "{audio_path}" -map 0:v -map 1:a -c:v copy -c:a aac "{output_path}"'
print(cmd)
os.system(cmd)
sequenceDiagram participant CLI as inference.py participant SDK as StreamSDK participant A2M as Audio2Motion worker participant Stitch as MotionStitch worker participant Render as Warp/Decode/PutBack workers participant Writer as Writer/ffmpeg CLI->>SDK: "setup(source_path, output_path)" CLI->>SDK: "run(audio_path)" SDK->>A2M: "audio feature queue" A2M->>Stitch: "motion sequence" Stitch->>Render: "stitched keypoints" Render->>Writer: "frame images" Writer-->>CLI: "result.mp4"
Algorithm: Ditto offline inference
Input: source_path, audio_path, cfg_pkl, data_root
1. Parse cfg_pkl and bind every model path to data_root.
2. Register source image/video: crop face, extract keypoints and 3D feature.
3. Convert waveform to HuBERT audio feature.
4. Feed audio feature into audio2motion queue.
5. LMDM samples motion sequence in motion space.
6. MotionStitch mixes source motion, driving motion, gaze and pose controls.
7. Warp 3D feature, decode crop image, put crop back to original canvas.
8. Write temporary video and mux original audio by ffmpeg.
Output: result.mp4Ditto 的关键不是直接在像素空间扩散,而是在 motion 表示上做 diffusion。推理侧的核心路径是 core/atomic_components/audio2motion.py 调用 core/models/lmdm.py,再进入 core/models/modules/LMDM.py 和 core/models/modules/lmdm_modules/model.py。
flowchart LR X["noisy motion x_t"] --> Decoder["MotionDecoder"] Frame["cond_frame"] --> Decoder Audio["audio condition"] --> Tokens["condition tokens"] EmoEye["emotion / eye / source canonical"] --> Tokens T["timestep embedding"] --> Decoder Tokens --> Decoder Decoder --> X0["predicted x_start"] X0 --> DDIM["DDIM update"] DDIM --> Next["x_t_minus_1"]
MotionDecoder 会把当前 noisy motion、条件帧、音频条件以及 timestep embedding 组合起来。源码中的 guided_forward 是 classifier-free guidance 风格:先算无条件输出,再算有条件输出,然后用 guidance_weight 放大二者差值。
核心源码:guidance 与 DDIM 采样
guided_forward 不是伪概念,而是源码里直接实现的 classifier-free guidance:同一批 noisy motion 分别跑无条件和有条件分支,再用 guidance_weight 放大条件分支带来的差异。
core/models/modules/lmdm_modules/model.py- MotionDecoder.guided_forward / forward
def guided_forward(self, x, cond_frame, cond_embed, times, guidance_weight):
unc = self.forward(x, cond_frame, cond_embed, times, cond_drop_prob=1)
conditioned = self.forward(x, cond_frame, cond_embed, times, cond_drop_prob=0)
return unc + (conditioned - unc) * guidance_weight
def forward(
self, x: Tensor, cond_frame: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0
):
batch_size, device = x.shape[0], x.device
# concat last frame, project to latent space
# cond_frame: [b, dim] | [b, n, dim+1]
if self.multi_cond_frame:
# [b, n, dim+1] (+1 mask)
x = torch.cat([x, cond_frame], dim=-1)
else:
# [b, dim]
x = torch.cat([x, cond_frame.unsqueeze(1).repeat(1, x.shape[1], 1)], dim=-1)
x = self.input_projection(x)
# add the positional embeddings of the input sequence to provide temporal information
x = self.abs_pos_encoding(x)
# create audio conditional embedding with conditional dropout
keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
cond_tokens = self.cond_projection(cond_embed)
# encode tokens
cond_tokens = self.abs_pos_encoding(cond_tokens)
cond_tokens = self.cond_encoder(cond_tokens)
null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
cond_tokens = torch.where(keep_mask_embed, cond_tokens, null_cond_embed)ddim_sample 则负责从高斯噪声开始迭代更新 motion。每一步先预测 x_start 与噪声,再按 DDIM 公式更新到下一个时间步。
core/models/modules/LMDM.py- MotionDiffusion.ddim_sample
@torch.no_grad()
def ddim_sample(self, kp_cond, aud_cond, sampling_timesteps):
self.setup(sampling_timesteps)
cond_frame = kp_cond
cond = aud_cond
shape = (1, self.seq_frames, self.motion_feat_dim)
x = torch.randn(shape, device=self.device)
x_start = None
i = 0
for _, time_next in self.time_pairs:
time_cond = self.time_cond_list[i]
pred_noise, x_start = self.model_predictions(x, cond_frame, cond, time_cond)
if time_next < 0:
x = x_start
continue
alpha_next_sqrt = self.alpha_next_sqrt_list[i]
c = self.c_list[i]
sigma = self.sigma_list[i]
noise = self.noise_list[i]
x = x_start * alpha_next_sqrt + c * pred_noise + sigma * noise
i += 1
return x # pred_kp_seq训练分支的主入口是 MotionDiT/train.py。它通过 tyro 解析训练参数,创建 Trainer,再由 Trainer.train_loop 执行训练。关键路径包括 MotionDiT/src/trainers/trainer.py、MotionDiT/src/datasets/s2_dataset_v2.py、MotionDiT/src/models/LMDM.py 和 MotionDiT/src/models/modules/diffusion.py。
flowchart TD TrainPy["MotionDiT/train.py"] --> Trainer["Trainer"] Trainer --> Dataset["Stage2Dataset"] Dataset --> Batch["kp_seq / kp_cond / aud_cond"] Trainer --> LMDMTrain["LMDM.diffusion"] Batch --> LMDMTrain LMDMTrain --> Loss["MotionDiffusion.p_losses"] Loss --> Backward["loss backward + optimizer step"]
Stage2Dataset 会加载 motion、audio、emotion、eye 等预处理特征,切成固定长度片段,返回 kp_seq、kp_cond 和 aud_cond。p_losses 不只计算 position loss,还支持 velocity、acceleration、last-frame loss 和 regularization loss。
core/utils/load_model.py 通过文件后缀选择 ONNX Runtime、TensorRT 或 PyTorch backend。这让上层 atomic component 不必关心底层模型格式,也解释了为什么 README 可以提供 TensorRT、ONNX、PyTorch 三类权重路径。
模型加载源码:用后缀切换推理后端
这里是可替换 backend 的关键:.onnx 走 ONNX Runtime,.engine / .trt 走 TensorRT,.pt / .pth 才实例化 PyTorch module。
core/utils/load_model.py- load_model
def load_model(model_path: str, device: str = "cuda", **kwargs):
if kwargs.get("force_ori_type", False):
# for hubert, landmark, retinaface, mediapipe
model = load_force_ori_type(model_path, device, **kwargs)
return model, "ori"
if model_path.endswith(".onnx"):
# onnx
import onnxruntime
if device == "cuda":
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
model = onnxruntime.InferenceSession(model_path, providers=providers)
return model, "onnx"
elif model_path.endswith(".engine") or model_path.endswith(".trt"):
# tensorRT
from .tensorrt_utils import TRTWrapper
model = TRTWrapper(model_path)
return model, "tensorrt"
elif model_path.endswith(".pt") or model_path.endswith(".pth"):
# pytorch
model = create_model(model_path, device, **kwargs)
return model, "pytorch"
运动缝合源码:source 身份与 driving motion 如何混合
MotionStitch 是控制感最强的一段代码:它先按 use_d_keys 混合 source 与 driving motion,再处理眼睛、VAD、姿态偏移和 fade,最后把 source / driving motion 都变成 keypoint,并可选调用 stitching network 修正。
core/atomic_components/motion_stitch.py- MotionStitch.__call__
def __call__(self, x_s_info, x_d_info, **kwargs):
# return x_s, x_d
kwargs = self._merge_kwargs(self.overall_ctrl_info, kwargs)
if self.scale_ratio is None:
self.scale_b = x_s_info['scale'].item()
self.scale_ratio = self.scale_a / self.scale_b
self._set_scale_ratio(self.scale_ratio)
if self.relative_d and self.d0 is None:
self.d0 = copy.deepcopy(x_d_info)
x_d_info = _mix_s_d_info(
x_s_info,
x_d_info,
self.use_d_keys,
self.d0,
)
delta_eye = 0
if self.drive_eye and self.delta_eye_arr is not None:
delta_eye = self.delta_eye_arr[
self.delta_eye_idx_list[self.idx % len(self.delta_eye_idx_list)]
][None]
x_d_info = _fix_exp_for_x_d_info_v2(
x_d_info,
x_s_info,
delta_eye,
self.fix_exp_a1,
self.fix_exp_a2,
self.fix_exp_a3,
)
if kwargs.get("vad_alpha", 1) < 1:
x_d_info = ctrl_vad(x_d_info, x_s_info, kwargs.get("vad_alpha", 1))
x_d_info = ctrl_motion(x_d_info, **kwargs)
if self.fade_type == "d0" and self.fade_dst is None:
self.fade_dst = copy.deepcopy(x_d_info)
# fade
if "fade_alpha" in kwargs and self.fade_type in ["d0", "s"]:
fade_alpha = kwargs["fade_alpha"]
fade_keys = kwargs.get("fade_out_keys", self.fade_out_keys)
if self.fade_type == "d0":
fade_dst = self.fade_dst
elif self.fade_type == "s":
if self.fade_dst is not None:
fade_dst = self.fade_dst
else:
fade_dst = copy.deepcopy(x_s_info)
if self.is_image_flag:
self.fade_dst = fade_dst
x_d_info = fade(x_d_info, fade_dst, fade_alpha, fade_keys)
if self.drive_eye:
if self.pose_s is None:
yaw_s = bin66_to_degree(x_s_info['yaw']).item()
pitch_s = bin66_to_degree(x_s_info['pitch']).item()
self.pose_s = [yaw_s, pitch_s]
x_d_info = _fix_gaze(self.pose_s, x_d_info)
if self.x_s is not None:
x_s = self.x_s
else:
x_s = transform_keypoint(x_s_info)
if self.is_image_flag:
self.x_s = x_s
x_d = transform_keypoint(x_d_info)
if self.flag_stitching:
x_d = self.stitch_net(x_s, x_d)从源码结构看,实时性并不是单靠一个模型很快,而是来自“分阶段队列 + worker 并行 + 可替换推理后端”的组合。
本机没有 checkpoints、没有可用 nvidia-smi、也没有 conda,因此这次没有实际跑通推理命令。本文对运行方式的说明来自 README 和源码;对输出效果不做本机实测背书。
复习速查
- 项目定位:Ditto 是一个实时可控 Talking Head 系统,主分支负责推理,训练在
train分支。 - 核心路径:
inference.py -> StreamSDK -> Audio2Motion/LMDM -> MotionStitch -> WarpF3D -> DecodeF3D -> PutBack -> Writer。 - 关键设计:扩散模型不直接生成像素,而是在 motion space 生成运动,再通过 feature warping 和 decoder 变成视频帧。
- 工程价值:这个仓库展示了研究模型如何落成实时 pipeline:后端可替换、阶段可并行、控制逻辑独立于生成模型。
- 复现提醒:TensorRT、CUDA、GPU 架构和外部 checkpoint 是主要门槛;训练示例只是流程验证,不代表最终效果。
参考来源
- Ant Group. ditto-talkinghead. GitHub repository
- Li, Z. et al. Ditto: Motion-Space Diffusion for Controllable Realtime Talking Head Synthesis. arXiv:2411.19509
- digital-avatar. ditto-talkinghead checkpoints. HuggingFace model page