ELF: Embedded Language Flows
ELF(Embedded Language Flows)是 MIT 何恺明与 Jacob Andreas 团队于 2026 年 5 月提出的连续扩散语言模型。与当前主流的离散扩散模型(如 MDLM、Duo)不同,ELF 将去噪过程完全保留在连续的 embedding 空间中,只在最后一步(t=1)做 discretization。
- 全程驻留连续空间(t ∈ (0,1)),仅在 t=1 一步离散化
- Shared-weight network:同一网络同时承担去噪和最终解码,无需独立 decoder
- x-prediction 参数化:直接预测 clean embedding 而非 velocity,使 t→1 时梯度稳定
- 训练时 CFG 实现:无需额外 forward pass,通过 batch 中的条件/无条件预测计算 guidance target
- T5 frozen encoder embeddings:只用训练时,推理无额外模块
- arXiv:2605.10938
- PDF:PDF
- 代码:github.com/lillian039/ELF(JAX/TPU)
- HuggingFace Checkpoints:embedded-language-flows
- 机构:MIT CSAIL
扩散语言模型的两条路线
| 方法 | 代表工作 | 去噪空间 | Discretization |
|---|---|---|---|
| 离散 DLM | MDLM, Duo | token space | 每步都是离散的 |
| 连续 DLM | CDCD, DiffuSeq | embedding space | 每步有 CE supervision |
| Flow-based DLM | LangFlow, FLM | embedding space | 每步有 CE supervision |
| 本文 ELF | ELF | embedding space | 仅最后一步 |
连续 DLM 长期落后于离散方法,根源不在于语言"本质离散",而在于之前方法在训练中对中间状态施加 token-level 监督(cross-entropy loss),打断了 Flow Matching 连续轨迹的灵活性。Flow Matching(Albergo & Vanden-Eijnden, 2023)的出现提供了理论工具,使连续 embedding 扩散能够真正发挥效力。参见 连续扩散语言模型路线综述 了解更多背景。
- Flow Matching:连续时间扩散的统一框架。线性插值 $z_t = t \cdot x + (1-t) \cdot \varepsilon$,速度场 $v = x - \varepsilon$。
- Rectified Flow:使用线性插值路径,相比 SDE/ODE 更简洁。
- x-prediction:直接预测 clean embedding 而非 velocity $v$,t→1 时梯度不会消失。
- Shared-weight Network:同一网络同时承担去噪和解码,无需独立 decoder。
- Self-conditioning:用模型自身的上一次预测作为额外条件,最早来自 Analog Bits(Chen et al., 2023)。
- CFG:条件与无条件预测的线性组合增强生成质量,在连续空间原生支持。
之前连续 DLM 在每步都施加 token-level cross-entropy supervision,等于把连续流形强制"拽回"到离散 token 空间。这破坏了 Flow Matching 最核心的优势——在连续空间自由探索最优轨迹的能力。
ELF 只在最后一步做 discretization,去噪全程在连续空间完成,流动态获得最大灵活性。
解码可以内嵌到去噪网络,不需要独立 decoder。embedding 空间到 token 空间的映射,本质上是一个线性投影(unembedding matrix),与去噪过程共享参数即可。unembedding 权重与 embedding 权重绑定(tied weights),最终时间步的输出直接通过这个线性映射得到 token logits,不需要额外的神经网络 decoder。
x-prediction > v-prediction 在语言建模中至关重要。当 $t \to 1$ 时,$v = (x - z_t)/(1-t)$ 中分子分母同时趋于 0,梯度消失。x-prediction 在整个时间域内保持稳定梯度,而且与最终 discretization 天然一致。
CFG 在连续空间几乎零成本实现。图像扩散中 CFG 需要两个 forward pass。ELF 通过 batch 内计算技巧,在训练时同时得到条件和无条件预测,无需额外的网络前向传播。
输入 token 序列 → T5 frozen encoder → continuous embeddings → ELF 去噪(连续时间 Flow Matching)→ 最后一 Step discretization → 输出 token 序列
整体架构:去噪和解码由同一个神经网络完成,通过 decoder_step_active 标志区分。
ELFBlock:Transformer 块
代码路径:~/code/elf/src/modules/model.py
class ELFBlock(nn.Module):
hidden_size: int
num_heads: int
mlp_ratio: float = 4.0
def __call__(self, x, rope_fn=None, attention_mask=None, deterministic=True):
x_normed = RMSNorm(self.hidden_size, eps=1e-6)(x)
attn_out = Attention(
self.hidden_size, self.num_heads,
qkv_bias=True, qk_norm=True, # QK-Norm 稳定训练
name='attn',
)(x_normed, rope_fn, attention_mask=attention_mask)
x = x + attn_out
x_normed = RMSNorm(self.hidden_size, eps=1e-6)(x)
mlp_out = SwiGLUFFN(self.hidden_size, mlp_hidden_dim)(x_normed)
x = x + mlp_out
return x
关键设计:qk_norm=True(RMSNorm 在 QKV 投影前)+ SwiGLU FFN(门控线性单元),结构接近 Llama/LLaMA 族的设计。
Context Prefix:时间条件和自条件注入
ELF 不使用 AdaLN / timestep embedding 直接注入,而是通过 prefix token 注入时间信息和自条件 CFG scale:
def build_context(self, t, self_cond_cfg_scale=None):
prefix_tokens = []
# 时间步 tokens:可学习向量 + 时间步嵌入
time_emb = TimestepEmbedder(self.hidden_size)(t)
prefix_tokens.append(time_emb) # num_time_tokens=4
# 自条件 CFG tokens:可学习向量 + CFG scale 嵌入
if self_cond_cfg_scale is not None:
sc_emb = TimestepEmbedder(self.hidden_size)(self_cond_cfg_scale)
prefix_tokens.append(sc_emb) # num_self_cond_cfg_tokens=4
return prefix_tokens # 总共 8 个 prefix tokens
这种设计让模型通过 prefix 中的 token 来"感知"当前时间步和 CFG 强度。
BottleneckTextProj:维度压缩与扩张
T5 encoder 输出 512d embedding,ELF Transformer hidden dim 是 768d,通过 BottleneckTextProj 实现维度转换(512 → 128 → 768)。瓶颈层迫使网络学习压缩的中间表示,128d 是关键设计——论文消融显示直接 512→768 产生极高 perplexity。
Dual-branch 训练:去噪 vs 解码
每次训练 step 中,以 decoder_prob 概率(默认 0.2)激活 decoder branch,以 1-decoder_prob 激活 denoiser branch:
- Decoder mode(20%):t=1,CE loss on tokens。对 x0 施加 logit-normal 噪声,提供有意义的训练信号。
- Denoiser mode(80%):random t,L2 loss on velocity。
CFG 训练时的 Label Drop
训练时以 label_drop_prob 概率将条件 token 的 embedding 清零,防止 target tokens attend to conditioning tokens。这确保了模型在 CFG 引导时真正学习无条件表示。
优化器:Muon
| 参数 | 值 |
|---|---|
| 优化器 | Muon |
| base lr (blr) | 0.001 |
| effective lr | blr × batch_size / 256 = 0.002(batch=512 时) |
| global batch size | 512 × 1024 tokens |
| 训练 epochs | 5 epochs OWT(~95K steps) |
EMA 更新策略
EMA 只在实际的 optimizer step(而非 gradient accumulation step)更新,避免 effective decay 变成 decay^grad_accum_steps。
SDE 采样:Hybrid 噪声缩放
ELF 的 SDE 采样采用 hybrid 设计(不是标准 DDPM 的固定量噪声):
def _sde_step(..., gamma, rng):
h = t_next - t
alpha = jnp.clip(1.0 - gamma * h, 0.0, 1.0) # 信号保留比例
t_back = alpha * t # 回到的时间点
# 加噪声回到 t_back
eps = jax.random.normal(rng, z.shape) * denoiser_noise_scale
z_back = alpha * z + (1.0 - alpha) * eps
# 前向传播
v_pred, x_pred = _forward_sample(..., z_back, t_back, ...)
# 从 z_back 走到 t_next
return z_back + (t_next - t_back) * v_pred
gamma 控制噪声注入强度。gamma=0 时退化为 ODE(Euler);gamma=1.0 时为默认 SDE;极少步数(4-8 步)时 gamma=2.0 增加随机性。
Self-conditioning 在采样中的应用
采样时使用上一步的 x_pred 作为 self-conditioning context,条件和无条件预测通过 batch 内计算获得,不需要两次完整 forward pass。
CFG 尺度从 Log-uniform 分布采样
训练时 CFG scale 从 log-uniform 分布采样,而非固定值:
def sample_cfg_scale(rng, batch_size, cfg_min=0.0, cfg_max=3.0):
u = jax.random.uniform(rng, (batch_size,))
a = jnp.float32(1.0 + cfg_min) # 1.5
b = jnp.float32(1.0 + cfg_max) # 4.0
return a * jnp.exp(u * jnp.log(b / a)) - 1.0 # 范围 [0.5, 3.0]
log-uniform 采样使 CFG 尺度在 [0, cfg_max] 范围内更均匀分布。
推理的两阶段生成流程
- Denoising 阶段:从 Gaussian 噪声 z_0 开始,通过 SDE/ODE 迭代走到 t=1,得到 clean embedding
- Decoding 阶段:在 t=1 用 decoder_branch 输出 token logits,通过 argmax/softmax 解码为离散 token
主实验:无条件生成(OWT 数据集)
| 模型 | 参数量 | Gen PPL ↓ | 采样步数 | 训练 tokens |
|---|---|---|---|---|
| MDLM | 170M | ~28 | 100+ | ~524B |
| Duo | 170M | ~26 | 100+ | ~524B |
| FLM | ~170M | ~27 | 100+ | ~500B+ |
| LangFlow | ~170M | ~26 | 100+ | ~500B+ |
| ELF-B | 105M | 24 | 32 | 45.2B |
ELF-B 105M 在 32 步采样内达到 Gen PPL 24,超越了 170M 的离散和连续 DLM,且训练 tokens 减少约 11 倍。
条件生成
| 任务 | 数据集 | BLEU / ROUGE-1/2/L |
|---|---|---|
| 机器翻译 | WMT14 De-En | 26.4 |
| 摘要 | XSum | 36.0 / 12.2 / 27.8 |
消融实验关键发现
| 消融项 | 结论 |
|---|---|
| x-prediction vs v-prediction | x-prediction 在 768d 上稳定,v-prediction 梯度消失 |
| CFG scale 最佳点 | 3.0 最优,>3 后 perplexity 回升 |
| SDE vs ODE | SDE 在少步数(8-16)时显著优于 ODE |
| Shared-weight vs two-stage decoder | 整体 trade-off 相近,shared-weight 在低 PPL 区域更优 |
| 128d bottleneck | 必须;512d 直接 projection 产生极高 perplexity |
- 去噪轨迹的完整性 > 每步精确度:ELF 证明让连续流形保持完整比在每步精确预测 token 更重要。类比:在不知道目的地时,沿着正确的坡度滑行比在错误的地点跳到精确坐标更好。
- Interface 设计决定上限:Diffusion-LM 和 CDCD 的失败不是因为"连续扩散做不好语言",而是在每步都重新"翻译"到离散空间打断了流形。ELF 将接口最小化(只一次),LangFlow 试图优化接口(每步 CE),这是两条不同的解决思路。
- x-prediction 在高维连续空间的重要性:标准 Flow Matching 用 v-prediction 在图像生成中 work,但在语言的高维 embedding 空间(768d)中,t→1 时的梯度消失问题变得严重。x-prediction 直接在数据流形上预测,规避了这个问题。
- "无 decoder"不是没有 decoder,而是 decoder 已隐含在线性映射中:这是 tied embedding/unembedding weights 的数学必然。需要 T5 encoder 的 embedding 和 ELF 网络的 hidden dim 兼容(512 → 128 bottleneck → 768)。
- ByteDance Cola DLM vs MIT ELF 的路线对比:Cola DLM 用 Text VAE + block-causal DiT,走"分层解耦"路线;ELF 用 frozen encoder + shared-weight network,走"最简适配"路线。两者共同点是都认为连续 embedding 是正确方向,分歧在于如何处理 embedding 空间和 token 空间的接口。
- Muon 优化器在扩散语言模型中的有效性:ELF 第一个使用 Muon 优化器替代 AdamW,通过 Gram-Schmidt 正交化稳定训练,适合 Transformer 架构。
- SDE/ODE hybrid 采样的实用价值:gamma 参数控制在 [0, 2] 范围内,可以根据采样步数自适应调整——少步数用高 gamma 增加随机性,多步数用低 gamma 接近确定性。