Masked Autoencoders Are Effective Tokenizers for Diffusion Models
MAETok(Masked Autoencoder Tokenizer)是 ICML 2025 的一篇论文,研究核心问题是:什么样的潜在空间对扩散模型是"好的"? 论文从理论和实验两个层面证明,潜在空间的 GMM mode 数量才是决定扩散模型生成质量的关键,而非 VAE 的变分约束形式。通过将 MAE 的 mask modeling 引入 plain AE tokenizer,MAETok 仅用 128 个 latent tokens 在 ImageNet 512×512 上达到 gFID 1.69 的 SOTA,训练速度提升 76×,推理吞吐提升 31×。
Image Tokenization
图像 tokenization 的发展脉络:早期 AE (Hinton & Salakhutdinov, 2006) → VAE (Kingma, 2013) → VQ-GAN (Esser et al., 2021)。近期 1D tokenizer 成为主流:TiTok (Yu et al., 2024) 用离散 token 序列表示图像;ImageFolder (Li et al., 2024) 和 SoftVQVAE (Chen et al., 2024) 探索连续 tokenization;VAVAE (Yao & Wang, 2025) 和 TexTok (Zha et al., 2024) 引入表征对齐的 VAE。MAETok 属于连续 tokenizer 路线,但完全抛弃了变分约束。
Image Generation
生成模型主要分为扩散模型和自回归模型两条路线。扩散模型方面:LDM (Rombach et al., 2022) 开创潜扩散范式,DiT (Peebles & Xie, 2023) 和 SiT (Ma et al., 2024) 改进骨干架构,REPA (Yu et al., 2024) 探索 representation alignment。自回归模型方面:MaskGIT、VAR、LlamaGen 等持续进步。
MAE 预训练
MAE (He et al., 2022) 开创了随机掩码重建的自监督学习范式,后续工作如 Xie et al., 2022 和 Wei et al., 2022 将其扩展到不同目标和更强特征。MAETok 将 MAE 的掩码建模思想引入 tokenizer 训练,填补了这一应用空白。
本文定位
MAETok 属于连续 tokenizer + 扩散模型的交叉路线。与 VAVAE/TexTok 不同的是:它不依赖变分约束,而是通过 MAE 的掩码建模直接学习判别性潜空间。与 TiTok/MAGVIT 不同的是:它不量化,输出连续 latent codes。
- Latent Diffusion Models
- 在 tokenizer 编码的低维潜在空间中进行扩散去噪,避免像素空间高维计算,典型如 SD-VAE 将 512×512 图像编码为 1024 个 latent tokens
- VAE vs AE
- VAE 通过 KL 约束使潜在分布平滑,但牺牲重建精度;AE 重建保真但潜在空间可能不够结构化。MAETok 证明了 AE + mask modeling 可以兼顾两者
- Masked Autoencoders (MAE)
- 随机遮蔽 40-80% 输入 patch,让模型从可见 patch 预测被遮蔽部分的特征,学到更判别性的表征
- GMM (Gaussian Mixture Model)
- 用 K 个高斯混合建模潜在分布,mode 数量反映分布复杂度。Mode 越多 → 分布越复杂 → 扩散模型越难学
- gFID / rFID
- gFID = generation FID(生成质量),rFID = reconstruction FID(重建质量)。MAETok 在两者上都优于 VAE baseline
- CFG (Classifier-Free Guidance)
- 扩散模型推理时用无条件预测引导条件预测,提升质量但增加计算量
- 什么样的潜在空间对扩散模型是"好的"?VAE 的变分约束真的是必要的吗?
- MAE 的 mask modeling 如何改善 tokenizer 的潜在空间?为什么它能减少 GMM mode 数量?
- 128 个 token 如何在 512×512 分辨率上达到 SOTA?压缩率这么高不会丢失细节吗?
- 辅助解码器为什么能同时学到多种目标特征(DINOv2 + CLIP + HOG)而不互相干扰?
- 两阶段训练(先 mask 训练 encoder,再 fine-tune decoder)的物理意义是什么?
论文证明:潜空间的 GMM 模式数越少,扩散模型学习越容易。在有限训练样本下,更多的 GMM 模式(VAE/AE)产生更差的生成质量。GMM 损失与扩散训练损失几乎对齐。
定理 2.1(GMM 样本复杂度)
DDPM 达到 $O(T\epsilon^2)$ 生成误差所需的样本数为:
其中 $K$ 是 GMM 模式数,$d$ 是维度,$B$ 是均值范数上界。关键的 $O(K^4)$ 关系意味着减少模式数可指数级降低训练难度。
核心设计:ViT-Base 编码器 + ViT-Base 解码器(共 176M 参数)。编码器输入 $N$ 个 image patch token + $L$ 个 learnable latent token($L=128$),随机掩码 40-60% 的 patch token,让 latent token 从 unmasked patches 聚合全局信息。辅助浅层解码器(3 层 Transformer)分别预测 HOG / DINOv2-Large / SigCLIP-Large 特征,仅在 masked 位置计算 MSE 损失。
高掩码比率会降低像素级重建质量(rFID 下降),但能改善生成质量(gFID 下降)。通过冻结编码器 + 微调解码器,可以在不牺牲潜空间判别性的前提下恢复重建保真度。
两阶段训练流程
- Stage 1 — Mask Modeling Training:带 mask modeling 训练整个 AE(encoder + pixel decoder + 3 个 auxiliary decoders),训练 500K iterations
- Stage 2 — Pixel Decoder Fine-Tuning:冻结 encoder,丢弃 auxiliary decoders,仅 fine-tune pixel decoder,mask ratio 从 60% 线性降到 0%,训练 50K iterations
Tokenizer 训练超参数
| 参数 | 值 |
|---|---|
| 架构 | ViT-Base encoder + ViT-Base decoder,共 176M 参数 |
| Latent space | L = 128 tokens,H = 32 维 |
| 训练迭代 | Stage 1: 500K / Stage 2: 50K |
| Mask ratio | Stage 1: 40-60% / Stage 2: 60% → 0% 线性衰减 |
| 辅助解码器 | 3 个(HOG + DINOv2-Large + SigCLIP-Large),各 3 层 Transformer |
| 损失权重 | λ₁ = 1.0(perceptual),λ₂ = 0.4(adversarial) |
| 数据集 | ImageNet 256×256 / 512×512 / LAION-COCO 512×512 |
| 初始化 | 从零初始化(无预训练权重) |
| 代码框架 | XQ-GAN codebase |
Diffusion 模型训练
| 模型 | 参数量 | 训练步数 | Patch size |
|---|---|---|---|
| SiT-XL | 675M | 4M steps | 1 |
| LightningDiT | 675M | 400K steps | 1 |
| SiT-L (ablation) | 458M | 400K steps | 1 |
所有 diffusion 模型使用 1D position embedding(适配 1D latent tokens),评估使用 250 inference steps,with/without CFG。
| 配置 | rFID ↓ | gFID ↓ |
|---|---|---|
| VAE(基线) | 1.22 | 22.17 |
| VAE + 掩码建模 | 1.75 | 18.17 |
| AE(基线) | 0.67 | 24.47 |
| AE + 掩码建模 | 0.85 | 5.78 |
| AE + MM + 解码器微调 | 0.48 | 5.69 |
AE + 掩码建模使 gFID 从 24.47 降至 5.78(降低 76%),而 VAE + MM 仅从 22.17 降至 18.17。KL 约束阻碍了潜空间学习。
| Aux. Decoder 深度 | rFID ↓ | gFID ↓ |
|---|---|---|
| 线性层 | 1.35 | 6.98 |
| 3 层(默认) | 0.85 | 5.78 |
| 12 层 | 0.96 | 8.80 |
过浅(1层):高层语义与低级细节混淆 → rFID 差。
过深(12层):容量过强,削弱 AE 的判别性潜空间 → gFID 差。3 层是最佳平衡点。
| 模型 | Tokenizer | 参数量 | Token 数 | gFID (无CFG) | gFID (有CFG) |
|---|---|---|---|---|---|
| MAETok + SiT-XL | AE | 675M | 128 | 2.31 | 1.67 |
| MAETok + LightningDiT | AE | 675M | 128 | 2.21 | 1.73 |
| REPA + SiT-XL | KL | — | 256 | 5.90 | 1.42 |
| LightningDiT | KL | 675M | 256 | 2.17 | 1.35 |
| DiT-XL/2 | — | 675M | — | 9.62 | 2.27 |
| 模型 | Token 数 | gFID (无CFG) | gFID (有CFG) | IS (有CFG) ↑ |
|---|---|---|---|---|
| MAETok + SiT-XL | 128 | 2.79 | 1.69 | 304.2 |
| MAETok + LightningDiT | 128 | 2.56 | 1.72 | 224.5 |
| MAETok + USiT-2B | 128 | 1.72 | 1.65 | 282.3 |
| USiT-2B(原文) | 256 | 3.50 | 2.43 | 234.8 |
| MAR-H (943M) | 1024 | 2.74 | 1.73 | 279.9 |
| DiT-XL/2 | — | 9.62 | 3.04 | 240.8 |
MAETok + SiT-XL 仅用 128 token 超越 2B 参数 USiT(256 token)和 943M MAR-H(1024 token)。在 CFG 条件下达到 gFID 1.69,IS 304.2,为当前 SOTA。
| Tokenizer | Token 数 | GFLOPs | Throughput (img/s) |
|---|---|---|---|
| SD-VAE(1024 tokens) | 1024 | 373.3 | 0.1 |
| MAETok(128 tokens) | 128 | 48.5 | 3.12 |
128 tokens 相比标准 1024 tokens,FLOPs 降低 7.7×,推理吞吐提升 31×。训练速度提升 76×(达到同等 REPA 性能)。
实践启示:在资源受限场景(移动端、边缘设备)下,MAETok 的极致压缩率(512×512 → 128 tokens)使得实时高分辨率生成成为可能。
Figure 4:AE 和 VAE 的潜空间不同类别大量重叠,而 MAETok 展现出清晰分离的聚类结构,类间边界分明。
Figure 5:Linear Probing 准确率与 gFID 高度相关 — 更判别性的潜空间 → 更好的生成质量。
- 理论:潜在分布的 GMM mode 越少,扩散模型所需训练样本越少($\sim K^4$),生成质量越好
- 方法:MAE 的 mask modeling 能在 plain AE 上学到判别性潜在空间,不需要 VAE 的变分约束
- 解耦效应:encoder 学判别性表征 和 decoder 学高保真重建可以分开——先 mask 训练 encoder,再 fine-tune decoder
- 实践:128 tokens 的 MAETok + 675M SiT-XL 在 ImageNet 512×512 上达到 gFID 1.69 的 SOTA
局限性与未来方向
- 辅助解码器引入了额外计算开销,虽然论文称其 overhead 很小,但未给出量化数据
- 两阶段训练的调参空间较大(Stage 1 mask ratio、Stage 2 fine-tune 步数),泛化性未充分验证
- 在非 ImageNet 数据集(如 COCO、ADE20K)上的生成质量未报告
- 与更强 diffusion 骨干(如 DiT-XL/2 + classifier-free guidance tuning)的联合优化空间未被探索
- 分析框架:GMM mode 数提供了一个量化评估任意 tokenizer 潜空间质量的新视角,linear probing accuracy 可作为便捷的 proxy metric
- 设计思路:对于视觉 tokenizer 的后续设计,可以直接采用"plain AE + MAE mask modeling"的范式,抛弃 VAE 的变分约束
- 两阶段训练:当 encoder 和 decoder 的目标冲突时(判别性 vs 重建精度),解耦训练是有效的解决方案
- 多目标学习:辅助解码器用不同层次的特征(HOG/DINOv2/SigCLIP)引导潜空间学习,比单一目标更有效
- 极致压缩:128 tokens 超越 256-1024 tokens 的效果,说明 token 数量不是越多越好,关键是潜空间的判别性结构