ESC
输入关键词搜索文章
目录

ProgressiveDiTok v0.2-对扩散模型再理解——怎么处理条件信息?

加噪与去噪

扩散模型在最开始,是一个去噪模型。它的核心过程有两个,一个是加噪过程,一个是去噪过程。

我们可以把加噪过程写成以下的形式:

\(p(\mathrm x_t\mid\mathrm x_0)\)

其中:

$$p_\theta(\mathbf{x}_{t-\Delta t} \mid \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-\Delta t}; \mu_\theta(\mathbf{x}_t, t), \Sigma_\theta(\mathbf{x}_t, t))$$

由于模型能够理解图像的真实分布 \(p(\mathrm x_0)\), 所以模型每一步去噪以后,都可以对图像当前的位置进行判断,从而了解到在当前位置,要向什么方向前进才能更接近于真实图像。

这个过程可以表示为:

\(p_{\theta,\mathrm x_0\sim p( \mathrm x_0) }(\mathrm x_{t-\Delta t}\mid \mathrm x_t)\)

这就相当于在黄金矿工里面,你每一次下钩以后,你的钩子只前进 1 米,然后在新的位置,你的钩子再次开始摇摆,然后你可以再次发射钩子。很显然,这使得我们的模型更容易到达目的地。

再举一个例子,这就相当于我们有一个龙珠探测器,可以用于寻找龙珠,它离龙珠越近,它的信号越强烈。自然是需要边找边看龙珠探测器的信号有没有变强。

具体如何从 \(\mathrm x_t\) 去噪得到 \(\mathrm x_{t-\Delta t}\), 可以参考以下的常微分方程(ODE),本文主要提供一个框架上的推导,不对具体公式做解释:

$$dx_t = -2\dot{\sigma}_t\sigma_t \nabla_{x_t} \log p(x_t; \sigma_t) dt + \sqrt{2\dot{\sigma}_t\sigma_t} dw_t$$

这个公式来自于 EDM 论文。

条件信息

这里有一个问题:我们只知道 \(\mathrm x_t\)\(p(\mathrm x_0)\), 但我们居然不知道 \(x_0\) 本身的任何信息。当然, \(\mathrm x_t\) 本身就带有部分 \(\mathrm x_0\) 的信息。

按流形假设,分布 \(p(\mathrm x_0)\) 是一个流形,流形中距离 \(\mathrm x_t\) 最近的点自然最有可能是 \(\mathrm x_0\).

但是,我们在去噪过程中毕竟不知道 \(\mathrm x_0\) 本身,这使得我们最终的目的地是不可知的,它可能很像真实图像,但是它最终大概率只是 \(p(\mathrm x_0)\) 上距离 \(\mathrm x_0\) 比较接近的一个点:它很真实,它很像 \(\mathrm x_0\), 但它不是 \(\mathrm x_0\).

既然我们的目标是要重建出 \(\mathrm x_0\), 那么最简单的一点:我们直接把 \(\mathrm x_0\) 当成条件信息传给模型,就有:

\(p_\theta(\mathrm x_{t-\Delta t}\mid \mathrm x_t, \mathrm x_0)\)

这个逻辑是清晰的,模型知道 \(x_0\) 再从 \(x_t\) 中逐步重建出 \(x_0\) ……

很好,但有什么用呢?简直就是脱裤子放屁,模型学会了一个恒等映射。

这个模型自然是没有用的,但是有了它以后,方便我们理解后面的过程。

在真实任务中,我们虽然不知道 \(\mathrm x_0\), 但我们总能知道有关于 \(\mathrm x_0\) 的某些信息,我们将这些信息统称为条件信息。

在文生图任务中,我们提供的文本是这个条件信息;图像重建任务中,损坏的图像可以提供这个条件信息;图像编辑任务中,修改指令可以提供这个条件信息。

我们用 \(y\) 来表示这个条件信息。这个符合使用与 DiT 保持了一致。

有了\(y\) 以后,我们把扩散模型工作的过程重新表达为:

$$p(\mathrm x_0 \mid \mathrm y) \propto p(\mathrm y \mid \mathrm x_0) p(\mathrm x_0)$$

其中, \(\propto\) 右侧表达从 \(\mathrm x_0\) 中提取出条件信息 \(\mathrm y\)

\(\mathrm y\) 可以表示为 \(\mathrm y=\mathcal A(\mathrm x_0)+\mathrm n\)其中, \(\mathcal A\) 是观测过程, \(\mathrm n\sim \mathcal N(0, \beta^2_y \mathrm I)\)

\(p(\mathrm y\mid\mathrm x_0)=\mathcal N(\mathcal A(\mathrm x_0),\beta^2_y\mathrm I)\)

这是一个逆问题。

逆问题要从部分的、带噪的观测(measurement) \(y\) 中恢复出原始图像 \(x_0\)

我们把上面的表述再说一遍,方便读者理解:在文生图任务中,\(\mathrm y\)我们提供的文本;图像重建任务中,\(\mathrm y\) 是损坏的图像;图像编辑任务中,\(\mathrm y\) 是修改指令。

对于图像编辑任务,这个逆问题完全可以表示成:

\(p_\theta(\mathrm x_{t-\Delta t}\mid \mathrm x_t, \mathrm x_0, \mathrm y)\)

不过大多的逆问题依然不知道 \(\mathrm x_0\), 因此我们也将这个过程表达为两步的过程。因此,我们也就能很顺畅地将图像的重建过程划分出编码过程与解码过程。

条件提取过程(编码器过程):

$$p_{\theta_\text{enc}}(\mathrm y \mid \mathrm x_0)$$

去噪过程(解码器过程):

对于每一个 \(t\) ,我们有去噪过程:

$$p_{\theta_\text{dec}}(\mathrm x_{t-\Delta t}\mid \mathrm x_t, \mathrm y)$$

因为多了一步从 \(\mathrm x_0\) 中进行观测,从而得到 \(\mathrm y\), 再从 \(\mathrm y\) 中恢复出 \(\mathrm x_0\) 的过程,我们也就多了一个条件 \(\mathrm y\).

我们的条件信息从 \(p(\mathrm x_t)\) 变成了 \(p(\mathrm x_t, \mathrm y)=p(\mathrm x_t)p(\mathrm y\mid \mathrm x_t)\)

由此,我们对第一节得到的方程进行修改:

$$\text d\mathrm x_t = -2\dot{\sigma}_t\sigma_t \nabla_{\mathrm x_t} \log p(\mathrm x_t; \sigma_t) \text dt -2\dot{\sigma}_t\sigma_t \nabla_{\mathrm x_t} \log p(\mathrm y |\mathrm x_t) \text dt+ \sqrt{2\dot{\sigma}_t\sigma_t} \text dw_t$$

因为条件信息的作用是替代 \(\mathrm x_0\), 提供原始信息,在不同的扩散时间步中,始终给模型提供一个目标指导,所以一般来说,不同的扩散时间步中, \(y\) 的信息是不会变的。

随时间步变化的条件信息: \(y\rightarrow y_t\)

在上一节中,我们的条件信息是与时间无关的,这其实有一些死板,这意味着,无论什么时候,条件信息所表示的都是 \(\mathrm x_0\) 的结构信息。

也就是\(\mathrm y=\mathcal A(\mathrm x_0)+\mathrm n\). 且 \(\mathrm n\sim \mathcal N(0, \beta^2_y \mathrm I)\)

如果 \(\mathrm y\) 本身就是 \(\mathrm x_0\), 也就是说没有信息损失, \(\beta_\mathrm y^2\) 为 0,这没有什么影响,因为它始终可以提供完整的目标图像信息。

但由于 \(\mathrm y\) 并不能提供 \(\mathrm x_0\) 中完整的信息,这就表示,我们虽然显式地、始终如一地向扩散模型提供了原始图像的信息指导,可是这种信息指导始终是带噪的,而且这种噪声并没有随着去噪过程进行而被消解。

可是去噪过程随着去噪接近了目标,其实会提供出新的信息,也就是说,如果 \(\mathrm y\) 能充分利用前一步去噪的结果,对 \(\mathrm y_t\) 进行更新,其实可以让 \(\mathrm y\) 本身也实现去噪。

但是和 \(\mathrm x_t\) 一样,在去噪的过程中,\(\mathrm y\) 的信息会损失,我们不希望有这种信息损失,那怎么办呢……没错,把 \(\mathrm y\) 本身也保存。因此我们有:

\(p_\theta(\mathrm y_{t}\mid\mathrm y,\mathrm y_{t+\Delta t}, \mathrm x_t, t)\)

我们之前假设了真实向模型中传递 \(\mathrm x_0\) 的情况,其过程也可以表达为: \(p_\theta(\mathrm x_{t}\mid \mathrm x_0, \mathrm x_{t+\Delta t}, \mathrm y, t)\)

这个式子和我们更新 \(y\) 的过程有类似之处。

(加不加 \(t\) 是一样的,因为 \(\mathrm x_t\) 里面隐式包含了 \(t\). )

我们可以一起看一下 DiT 的结构:

标签 \(y\) 在每一轮去噪都是相同的,修改的是时间步 \(t\).

有:

\(p(\mathrm y_t\mid y, t)\)

但是时间步 \(t\) 中其实没有任何有效信息,因此 \(\mathrm y\) 没有被进行有效的信息更新,只是多传递了 \(t\) 的信息而已。当然,在 DiT Block 中,依然会隐式地对 \(\mathrm y_t\) 进行建模,但这种隐式解模的效果不一定好。至少从 DDT 的研究中可以得出,这种建模的效果其实不行。

再看 DDT:

DDT 解耦了 Condition Encoder 与 Velocity Decoder,

其中 Condition Encoder 所做的工作就是:

\(p_\theta(\mathrm y_{t}\mid \mathrm x_t, \mathrm y, t)\)

没有传递 \(\mathrm y_{t+\Delta t}\), 不过这实际上影响不大。首先,扩散模型的结构束缚,导致 DDT 其实不太方便传入 \(\mathrm y_{t+\Delta t}\), 其次,Velocity Decoder 的工作是 \(p_\theta(\mathrm x_t \mid \mathrm x_{t+\Delta t}, \mathrm y_{t+\Delta t}, t)\) 因此,\(\mathrm x_t\) 中包含了 \(\mathrm y_{t+\Delta t}\) 的信息。

在 DDT 结构下,将 \(\mathrm y\) 更新到 \(\mathrm y_t\) 的过程就变成显式被解耦的了,而结果也表明, DDT 的效果确实要比 DiT 更好, FID 尤其出色。

条件信息反馈回路

现在我们将条件扩散模型重新解释成了编码器、解码器结构,同时我们发现,发展到 DDT, 扩散模型的 \(y_t\) 实际上一直在通过 \(\mathrm x_t\) 的去噪过程,对 \(\mathrm y\) 的条件信息进行更新,从而确保模型可以对 \(\mathrm y\) 提供的信息保真。

但是我们发现,现在对于 \(\mathrm y_t\) 的更新其实局限在了解码器端。

$$p_\theta(\mathrm y_{t}\mid\mathrm y,\mathrm y_{t+\Delta t}, \mathrm x_t, t)$$

这个过程足够好吗?

对于图像重建,编码器是对图像造成损伤的过程, \(\mathrm y\) 也许就是我们能得到的所有信息,自然,我们只能做到这一步。

但是如果我们要做图像编解码传输,这却是不一定的:

我们始终有办法接触到原始图像!

那我们是不是可以进一步思考:

我们现在有了 Encoder 用于从原始图像中提取图像的信息 \(\mathrm y=\mathcal A(\mathrm x_0) + n\),为什么我们要把 \(y\) 一口气传递,而后让模型进行多步扩散,在解码端更新 \(\mathrm y_t\) 呢?

我们可以在编码器端就完成对图像的 \(\mathrm y\) 的更新。

也就是说:

$$\mathrm y_t=\mathcal A(\mathrm x_0, \mathrm x_{t} )+ \mathrm n$$

现在, \(\mathrm y_t\) 所传递的信息变得物理意义不那么明确了。不过我们可以确定的是,它部分表达了从 \(\mathrm x_{t}\) 恢复到 \(\mathrm x_0\) 的速度场。虽然由于 \(\mathrm y_t\) 维度的限制,它可能不能表达完整的速度场,但是至少它可以表达一个分量,确保模型能够向着 \(\mathrm x_0\) 靠近。

最简单地,我们给出一个方便理解的表达,以证明这种方法的可行性:

$$\mathrm y_t=\frac {x_{t}-x_0} {t}[s:s+d]$$

其中 \(d\)\(\mathrm y_t\) 的维度。

那么模型只需要预测一个 \(s\) 并令速度为:

$$\mathrm v_t[i]= \mathrm y_t[i] \text{ if }s<i<s+d \text{ else } 0$$
$$\arg\min_s x_0 - (x_{t}- t v_t)$$

这个学习目标是简单的且有效的。这确保了我们学习的有效性。模型可以根据 \(\mathrm x_t\) 动态更新 \(\mathrm y_t\) 从而了条件信息的准确,能有效地提高重建的准确度。

当然,我们还需要对 \(\mathrm y_t\) 进行量化。

那么假设 \(\mathrm y_t\) 可以在 \([0, 0.25, 0.5, 0.75, 1]\) 中取值。

很显然,我们依然可以传递去噪方向的信息。

当然,这里的问题是,编码器不知道 \(\mathrm x_t\) 的信息,我们也不能将整个 \(\mathrm x_t\) 的信息传回到编码端。

我们只能从解码端将部分信息传回到编码端。

幸好,我们上方计算 \(\mathrm v_t\) 实际上也只使用到了 \(\mathrm x_t\)\(d\) 维的信息,其余的信息都是冗余的。

因此,我们假设模型在解码端完成了对需要更新的维度的选择,在编码器端只进行简单的速度计算,从而传递了 \(\mathrm x_t\) 部分维度的去噪方向。这依然能有效地实现对模型进行去噪。

我们将 \(t\) 时刻解码端传回到编码端的信息表示为 \(\mathrm m_t\).

因此,我们需要构建一个图像重建反馈回路:

编码器端生成出 \(\mathrm y_t\) 传递给解码端,解码端借此更新生成图像 \(\mathrm x_{t-\Delta t}\).

编码器端同时生成出 \(\mathrm m_{t-\Delta t}\) 传递给编码端,编码端借此更新图像条件信息 \(\mathrm y_{t-\Delta t}\).

具体的结构我们可以可以继续使用 1D Tokenizer 进行实现。

Encoder 端与 Decoder 端在图像的 tokens 之上都拼接上 latent tokens, 而后 latent tokens 在 Encoder 和 Decoder 端传递。

具体的,我们的一个去噪过程可以表达为:

$$\left\{\begin{array}{l} &\text{Encoder } \mathcal E:&p_\theta(\mathrm y_t\mid {\mathrm x_0}, \mathrm{m}_t)\\ &\text{Quantizer}\mathcal Q_1:&p_\theta(\mathrm{\tilde y}_t\mid \mathrm{y}_t )\\ &\text{Decoder } \mathcal D:&p_\theta(\hat{\mathrm{x}}_{t-\Delta t}\oplus\mathrm{m}_{t-\Delta t}\mid \hat{\mathrm{x}}_{t}, \mathrm{\tilde y}_t)\\ &\text{Quantizer }\mathcal Q_2:&p_\theta(\mathrm{\tilde m}_{t-\Delta t}\mid \mathrm{m}_{t-\Delta t} ) \end{array} \right.$$

条件信息反馈回路的训练

我们已经构建好了条件信息反馈的回路。下一个问题就是:

我们要如何能实现对这个反馈回路的训练?

对于一个流匹配的扩散模型,它在推理的时间,处理的是

$$p_{\theta,\mathrm x_0\sim p( \mathrm x_0) }(\mathrm x_{t-\Delta t}\mid \mathrm x_t)$$

但是在训练时,模型实际上推理的是

$$p_\theta(\mathrm x_0\mid \mathrm x_t)$$

DDPM 中,模型预测的是时间步为 \(t\) 时的总噪声; Flow Matching 中,模型预测的是时间步为 \(t\) 时,向着原始图像前进的速度场;JiT 中,模型直接预测的就是 \(\mathrm x_0\).

无论是哪一个,训练过程都绕过了模型去噪对上一时刻的依赖,模型学会的是去噪能力,这种能力在噪声量任意时,都是等价的。

我们只需要随机采样时间步 \(t\), 训练模型在时间为 \(t\) 的条件下对模型进行去噪,让模型去噪结果尽可能地接近于原始图像即可。

然而,如果我们要构建起条件信息反馈回路,我们可以看到,我们的编码器始终依赖于上一轮去噪时解码端输出的 \(\mathrm m_t\). 这导致我们不能直接开始训练编码器。

不过,我们注意到,当 \(t=1\) 时,有\(\mathrm m_1\), 它并不是一个解码端的输出,而可以是一个可学习参数。

同时,我们考虑到 \(\mathrm y_t\)\(\mathrm y_1\) 应该是相似的,毕竟二者的传递的都是 \(x_0\) 的部分信息,只不过前者更适用于 \(t\) 时刻的解码过程。但是, \(\mathrm y_t\) 是用来与 \(\mathrm{\hat x}_t\) 搭配,为后者指明去噪方向的。但由于模型去噪得到的 \(\hat{\mathrm x}_t\) 我们同样无法得到,只能用 \(\mathrm x_t\) 来近似。对于它来说 \(\mathrm y_1\) 反而是更合适的选择。

因此,要想得到我们需要的 \(\mathrm m_t\) ,我们需要经过一个比较复杂的过程:

这就完成了一个完整的反向传播过程,并且多了一个超参数:

$$L_\text{total}=w(t_1, t_2)L_1+\left(1-w(t_1, t_2)\right)L_2$$

以重建损失为例,我们有:

$$\hat{\mathrm x}_0=\mathrm{x}_{t_1}-\hat {\mathrm{v}}_{t_1}(t_1 - t_2 )-\hat {\mathrm{v}}_{t_2}t_2$$

而实际上的 \({\mathrm v}_{t_1}=\mathrm v\)\(x_{t_1}\) 是加噪得到的,所以它的方向就是 \(\mathrm x_{1} - \mathrm x_0\) 的方向)却为:

$$\mathrm v_{t_1}=\frac {\mathrm x_{t_1} -\mathrm x_0}{t_1}$$

也即:

$$\mathrm x_0=\mathrm x_{t_1}-\mathrm v_{t_1}t_1$$

两式相减.

$$\hat{\mathrm x}_0-{\mathrm{x}}_0=(\mathrm v_{t_1}-\hat{\mathrm v}_{t_1})(t_1-t_2)+(\mathrm v_{t_1}-\hat{\mathrm v}_{t_2})t_2$$

如果求 \(x\)-loss, 就只需要求解:

$$L_2(\mathrm x_0)=\left\|\mathrm x_0-\mathrm {\hat x}_0\right\|^2_2$$

但这样就无法即约束 \(\mathrm v_{t_1}\) 又约束 \(\mathrm v_{t_2}\), 也没有利用到时间 \(t_2\) 这一信息。

因此我们使用损失:

$$L_2(\mathrm v_1)=\frac {t_1-t_2}{t_1}\left\|\mathrm v_{t_1}-\hat {\mathrm{v}}_{t_1}\right\|^2_2+\frac {t_2}{t_1}\left\|{\mathrm{v}}_{t_1}-\hat {\mathrm{v}}_{t_2}\right\|^2_2$$

借此,我们训练了模型修正去噪方向的能力。

代码实现

我们的模型被称为 progressiveDiTok.

它由以下几个组件构成:

由于存在一些 token 维度问题,我们需要添加以下这些额外的中间层,进行维度的转换:

forward 过程:

解耦退火后验采样

我们现在的扩散过程每一步输出的结果只能将图像从模糊渐渐变清晰,在这个过程中得到的图像其实不太能看。

那么有没有可以,我们得到这样一个过程:每次一步扩散得到的结果都合乎人类感知——看着清晰,但保真度可能一般。但是每一次扩散又能给模型提供新的信息,让模型在真实性上也越来越好,模型从清晰但失真,越来越逼近真实图像的清晰且正确?

可以参考DAPS.

解码端可以边解码,边把图像显示给用户。显示给用户的图像为第二列的 \({\mathrm {\hat x}_0}(\mathrm x_t)\)

DAPS 提供的是一个后验采样框架,也就是说,与我们的模型训练过程无关,我们可以将模型训练完以后,再测试 DAPS 在我们的模型上使用的效果。