ESC
输入关键词搜索文章
目录

ELF: Embedded Language Flows

arXiv 2605.10938 · MIT CSAIL · 2026
105M 参数干翻离散 DLM:去噪全程驻留连续空间,仅最后一步离散化
简介

ELF(Embedded Language Flows)是 MIT 何恺明与 Jacob Andreas 团队于 2026 年 5 月提出的连续扩散语言模型。与当前主流的离散扩散模型(如 MDLM、Duo)不同,ELF 将去噪过程完全保留在连续的 embedding 空间中,只在最后一步(t=1)做 discretization。

核心结论:105M 参数的 ELF-B 在 32 步采样内达到 Gen PPL 24,优于 170M 的离散扩散模型,且训练 tokens 减少 10 倍。这不是通过更大的模型规模实现的,而是通过更简洁的架构设计——shared-weight network、无独立 decoder、无每步 token-level supervision。
  • 全程驻留连续空间(t ∈ (0,1)),仅在 t=1 一步离散化
  • Shared-weight network:同一网络同时承担去噪和最终解码,无需独立 decoder
  • x-prediction 参数化:直接预测 clean embedding 而非 velocity,使 t→1 时梯度稳定
  • 训练时 CFG 实现:无需额外 forward pass,通过 batch 中的条件/无条件预测计算 guidance target
  • T5 frozen encoder embeddings:只用训练时,推理无额外模块
参考链接
相关工作

扩散语言模型的两条路线

方法代表工作去噪空间Discretization
离散 DLMMDLM, Duotoken space每步都是离散的
连续 DLMCDCD, DiffuSeqembedding space每步有 CE supervision
Flow-based DLMLangFlow, FLMembedding space每步有 CE supervision
本文 ELFELFembedding space仅最后一步

连续 DLM 长期落后于离散方法,根源不在于语言"本质离散",而在于之前方法在训练中对中间状态施加 token-level 监督(cross-entropy loss),打断了 Flow Matching 连续轨迹的灵活性。Flow Matching(Albergo & Vanden-Eijnden, 2023)的出现提供了理论工具,使连续 embedding 扩散能够真正发挥效力。参见 连续扩散语言模型路线综述 了解更多背景。

前置知识
  • Flow Matching:连续时间扩散的统一框架。线性插值 $z_t = t \cdot x + (1-t) \cdot \varepsilon$,速度场 $v = x - \varepsilon$
  • Rectified Flow:使用线性插值路径,相比 SDE/ODE 更简洁。
  • x-prediction:直接预测 clean embedding 而非 velocity $v$,t→1 时梯度不会消失。
  • Shared-weight Network:同一网络同时承担去噪和解码,无需独立 decoder。
  • Self-conditioning:用模型自身的上一次预测作为额外条件,最早来自 Analog Bits(Chen et al., 2023)。
  • CFG:条件与无条件预测的线性组合增强生成质量,在连续空间原生支持。
Insight:去噪轨迹的完整性决定生成质量

之前连续 DLM 在每步都施加 token-level cross-entropy supervision,等于把连续流形强制"拽回"到离散 token 空间。这破坏了 Flow Matching 最核心的优势——在连续空间自由探索最优轨迹的能力。

ELF 只在最后一步做 discretization,去噪全程在连续空间完成,流动态获得最大灵活性。

类比:图像扩散从来不在中间步骤"四舍五入到像素值",语言扩散也不应该每步"四舍五入到 token"。

解码可以内嵌到去噪网络,不需要独立 decoder。embedding 空间到 token 空间的映射,本质上是一个线性投影(unembedding matrix),与去噪过程共享参数即可。unembedding 权重与 embedding 权重绑定(tied weights),最终时间步的输出直接通过这个线性映射得到 token logits,不需要额外的神经网络 decoder。

x-prediction > v-prediction 在语言建模中至关重要。$t \to 1$ 时,$v = (x - z_t)/(1-t)$ 中分子分母同时趋于 0,梯度消失。x-prediction 在整个时间域内保持稳定梯度,而且与最终 discretization 天然一致。

CFG 在连续空间几乎零成本实现。图像扩散中 CFG 需要两个 forward pass。ELF 通过 batch 内计算技巧,在训练时同时得到条件和无条件预测,无需额外的网络前向传播。

模型架构(代码解析)
ELF Conceptual Illustration
图 1:ELF 概念示意图。橙色点表示连续 embedding 空间中的数据,紫色线表示从高斯噪声到干净 embedding 的去噪轨迹。Discretization 仅在最终时间步 t=1 应用(来源:ELF GitHub)。

输入 token 序列 → T5 frozen encoder → continuous embeddings → ELF 去噪(连续时间 Flow Matching)→ 最后一 Step discretization → 输出 token 序列

整体架构:去噪和解码由同一个神经网络完成,通过 decoder_step_active 标志区分。

ELFBlock:Transformer 块

代码路径:~/code/elf/src/modules/model.py

class ELFBlock(nn.Module):
    hidden_size: int
    num_heads: int
    mlp_ratio: float = 4.0

    def __call__(self, x, rope_fn=None, attention_mask=None, deterministic=True):
        x_normed = RMSNorm(self.hidden_size, eps=1e-6)(x)
        attn_out = Attention(
            self.hidden_size, self.num_heads,
            qkv_bias=True, qk_norm=True,  # QK-Norm 稳定训练
            name='attn',
        )(x_normed, rope_fn, attention_mask=attention_mask)
        x = x + attn_out

        x_normed = RMSNorm(self.hidden_size, eps=1e-6)(x)
        mlp_out = SwiGLUFFN(self.hidden_size, mlp_hidden_dim)(x_normed)
        x = x + mlp_out
        return x

关键设计:qk_norm=True(RMSNorm 在 QKV 投影前)+ SwiGLU FFN(门控线性单元),结构接近 Llama/LLaMA 族的设计。

Context Prefix:时间条件和自条件注入

ELF 不使用 AdaLN / timestep embedding 直接注入,而是通过 prefix token 注入时间信息和自条件 CFG scale:

def build_context(self, t, self_cond_cfg_scale=None):
    prefix_tokens = []
    # 时间步 tokens:可学习向量 + 时间步嵌入
    time_emb = TimestepEmbedder(self.hidden_size)(t)
    prefix_tokens.append(time_emb)  # num_time_tokens=4

    # 自条件 CFG tokens:可学习向量 + CFG scale 嵌入
    if self_cond_cfg_scale is not None:
        sc_emb = TimestepEmbedder(self.hidden_size)(self_cond_cfg_scale)
        prefix_tokens.append(sc_emb)  # num_self_cond_cfg_tokens=4

    return prefix_tokens  # 总共 8 个 prefix tokens

这种设计让模型通过 prefix 中的 token 来"感知"当前时间步和 CFG 强度。

BottleneckTextProj:维度压缩与扩张

T5 encoder 输出 512d embedding,ELF Transformer hidden dim 是 768d,通过 BottleneckTextProj 实现维度转换(512 → 128 → 768)。瓶颈层迫使网络学习压缩的中间表示,128d 是关键设计——论文消融显示直接 512→768 产生极高 perplexity。

训练细节(代码解析)

Dual-branch 训练:去噪 vs 解码

每次训练 step 中,以 decoder_prob 概率(默认 0.2)激活 decoder branch,以 1-decoder_prob 激活 denoiser branch:

  • Decoder mode(20%):t=1,CE loss on tokens。对 x0 施加 logit-normal 噪声,提供有意义的训练信号。
  • Denoiser mode(80%):random t,L2 loss on velocity。

CFG 训练时的 Label Drop

训练时以 label_drop_prob 概率将条件 token 的 embedding 清零,防止 target tokens attend to conditioning tokens。这确保了模型在 CFG 引导时真正学习无条件表示。

优化器:Muon

参数
优化器Muon
base lr (blr)0.001
effective lrblr × batch_size / 256 = 0.002(batch=512 时)
global batch size512 × 1024 tokens
训练 epochs5 epochs OWT(~95K steps)

EMA 更新策略

EMA 只在实际的 optimizer step(而非 gradient accumulation step)更新,避免 effective decay 变成 decay^grad_accum_steps。

采样详解(代码解析)
ELF Denoising Trajectory
图 2:ELF-B 的去噪轨迹。随着 t 从 0 增加到 1,不合语法的句子逐渐被精化为流畅、语法正确的文本(来源:ELF GitHub)。

SDE 采样:Hybrid 噪声缩放

ELF 的 SDE 采样采用 hybrid 设计(不是标准 DDPM 的固定量噪声):

def _sde_step(..., gamma, rng):
    h = t_next - t
    alpha = jnp.clip(1.0 - gamma * h, 0.0, 1.0)  # 信号保留比例
    t_back = alpha * t  # 回到的时间点

    # 加噪声回到 t_back
    eps = jax.random.normal(rng, z.shape) * denoiser_noise_scale
    z_back = alpha * z + (1.0 - alpha) * eps

    # 前向传播
    v_pred, x_pred = _forward_sample(..., z_back, t_back, ...)

    # 从 z_back 走到 t_next
    return z_back + (t_next - t_back) * v_pred

gamma 控制噪声注入强度。gamma=0 时退化为 ODE(Euler);gamma=1.0 时为默认 SDE;极少步数(4-8 步)时 gamma=2.0 增加随机性。

Self-conditioning 在采样中的应用

采样时使用上一步的 x_pred 作为 self-conditioning context,条件和无条件预测通过 batch 内计算获得,不需要两次完整 forward pass。

CFG 尺度从 Log-uniform 分布采样

训练时 CFG scale 从 log-uniform 分布采样,而非固定值:

def sample_cfg_scale(rng, batch_size, cfg_min=0.0, cfg_max=3.0):
    u = jax.random.uniform(rng, (batch_size,))
    a = jnp.float32(1.0 + cfg_min)   # 1.5
    b = jnp.float32(1.0 + cfg_max)   # 4.0
    return a * jnp.exp(u * jnp.log(b / a)) - 1.0  # 范围 [0.5, 3.0]

log-uniform 采样使 CFG 尺度在 [0, cfg_max] 范围内更均匀分布。

推理的两阶段生成流程

  1. Denoising 阶段:从 Gaussian 噪声 z_0 开始,通过 SDE/ODE 迭代走到 t=1,得到 clean embedding
  2. Decoding 阶段:在 t=1 用 decoder_branch 输出 token logits,通过 argmax/softmax 解码为离散 token
实验
ELF System Comparison
图 3:ELF 系统级对比。(a) 在相似训练设置下,ELF-B 优于离散和连续 DLM;(b) 与需要额外蒸馏的 baseline 蒸馏变体相当;(c) 训练 tokens 少一个数量级(来源:ELF GitHub)。

主实验:无条件生成(OWT 数据集)

模型参数量Gen PPL ↓采样步数训练 tokens
MDLM170M~28100+~524B
Duo170M~26100+~524B
FLM~170M~27100+~500B+
LangFlow~170M~26100+~500B+
ELF-B105M243245.2B

ELF-B 105M 在 32 步采样内达到 Gen PPL 24,超越了 170M 的离散和连续 DLM,且训练 tokens 减少约 11 倍。

条件生成

任务数据集BLEU / ROUGE-1/2/L
机器翻译WMT14 De-En26.4
摘要XSum36.0 / 12.2 / 27.8

消融实验关键发现

消融项结论
x-prediction vs v-predictionx-prediction 在 768d 上稳定,v-prediction 梯度消失
CFG scale 最佳点3.0 最优,>3 后 perplexity 回升
SDE vs ODESDE 在少步数(8-16)时显著优于 ODE
Shared-weight vs two-stage decoder整体 trade-off 相近,shared-weight 在低 PPL 区域更优
128d bottleneck必须;512d 直接 projection 产生极高 perplexity
收获
  1. 去噪轨迹的完整性 > 每步精确度:ELF 证明让连续流形保持完整比在每步精确预测 token 更重要。类比:在不知道目的地时,沿着正确的坡度滑行比在错误的地点跳到精确坐标更好。
  2. Interface 设计决定上限:Diffusion-LM 和 CDCD 的失败不是因为"连续扩散做不好语言",而是在每步都重新"翻译"到离散空间打断了流形。ELF 将接口最小化(只一次),LangFlow 试图优化接口(每步 CE),这是两条不同的解决思路。
  3. x-prediction 在高维连续空间的重要性:标准 Flow Matching 用 v-prediction 在图像生成中 work,但在语言的高维 embedding 空间(768d)中,t→1 时的梯度消失问题变得严重。x-prediction 直接在数据流形上预测,规避了这个问题。
  4. "无 decoder"不是没有 decoder,而是 decoder 已隐含在线性映射中:这是 tied embedding/unembedding weights 的数学必然。需要 T5 encoder 的 embedding 和 ELF 网络的 hidden dim 兼容(512 → 128 bottleneck → 768)。
  5. ByteDance Cola DLM vs MIT ELF 的路线对比:Cola DLM 用 Text VAE + block-causal DiT,走"分层解耦"路线;ELF 用 frozen encoder + shared-weight network,走"最简适配"路线。两者共同点是都认为连续 embedding 是正确方向,分歧在于如何处理 embedding 空间和 token 空间的接口。
  6. Muon 优化器在扩散语言模型中的有效性:ELF 第一个使用 Muon 优化器替代 AdamW,通过 Gram-Schmidt 正交化稳定训练,适合 Transformer 架构。
  7. SDE/ODE hybrid 采样的实用价值:gamma 参数控制在 [0, 2] 范围内,可以根据采样步数自适应调整——少步数用高 gamma 增加随机性,多步数用低 gamma 接近确定性。