ESC
输入关键词搜索文章
目录

ProgressiveDiTok v0.3-从Next Token Prediction到 Next Scale Prediction

相关论文

VAR

扩散模型与自回归模型

自回归模型

我们知道现在 LLM 领域的主要生成范式就是自回归模型。自回归模型的一个显著特点是其因果性,输出序列中某一处的输出依赖于其之前的输出,但并不依赖于其之后的输出。

自回归的生成范式在图像生成领域也有应用。这些模型一般都先将图像进行 Patchify, 得到二维的 patches. 而后,将这些 patches 展开成一维的序列。在生成时,以逐个 token 生成的方式生成 tokens, 再从 token 序列重建出图像。

但这里显然有一些别扭:图像的 tokens 之间的关系是一种空间信息相互依赖,而不是 NLP 领域一维词序列的因果依赖。当前需要输出的图像 token 显然不仅依赖于当前的 token, 也依赖于其后的 token. 当需要生成的图像比较大, token 序列比较长的时候,之前生成的 token 中出现的错误会被积累,以至于图像生成失败。

在视觉模型方面,自回归模型却没有扩散模型更流行。至今为止,视觉生成领域的主流范式依然是扩散模型。

扩散模型

我们知道扩散模型的特点是,在较早的扩散步中,扩散模型主要生成图像的低频特征,也就是图像的语义特征,随着扩散步的增加,再逐渐生成图像的细节。这种范式与人类的视觉认知过程是相符的:从整体到局部,从语义到细节。

在以 DiT 为基础的扩散模型中,扩散模型与自回归模型一样,都需要将图像切分成 tokens. 但是在扩散模型的生成范式中,图像的 tokens 是同时被进行修改的。这确保了图像 tokens 的空间依赖被充分利用。

扩散模型与自回归模型

我们可以思考一下:扩散模型每一次进行扩散时,都会在上一轮扩散输出的图像的基础上添加新的特征。那么这些新生成的特征是不是也可以形成一个序列呢?

显然是可以的。而且我们知道,这个特征序列是从特频向高频前进。

既然我们可以把扩散模型每一轮新增的特征排成序列,那么我们就可以想一个办法,使用自回归模型来生成这些特征,从而使用自回归模型实现与扩散模型同样的效果。

基于Next-Token 预测的 AR 模型

首先,要想引入本文的方法,我们需要先回顾一下什么是 Next-Token Prediction.

这个方法最早用于 NLP 领域,比较经典的就是 GPT 系列。其过程可以表示为如下的公式:

$$p\left(x_1, x_2, \ldots, x_T\right)=\prod_{t=1}^{T} p\left(x_t\mid x_1, x_2, \ldots, x_{t-1}\right).$$

也就是说,模型根据已经生成得到的 tokens, 去预测下一个 token 的值。可以看到,这种范式中有一种因果性。模型无法得到下一个 token 的内容。

要训练一个 AR 模型,就需要训练模型得到 \(p\left(x_t\mid x_1, x_2, \ldots, x_{t-1}\right).\)的能力。

Tokenization

tokenization的公式如下: 给定原始图像 \(im\),编码器 \(\mathcal{E}(\cdot)\) 将其转换为特征图 \(f \in \mathbb{R}^{h \times w \times C}\)

$$f = \mathcal{E}(im)$$

量化器 \(\mathcal{Q}(\cdot)\) 使用一个可学习的码本 \(Z \in \mathbb{R}^{V \times C}\)(包含 \(V\) 个向量)将特征向量 \(f^{(i,j)}\) 量化为离散的令牌索引 $q^{(i,j)} ∈ [V]$

$$q^{(i, j)} = \left( \underset{v \in [V]}{\arg\min} \| \text{lookup}(Z, v) - f^{(i, j)} \|_2 \right)$$

其中 \(\text{lookup}(Z, v)\) 表示从码本 \(Z\) 中取出第 \(v\) 个向量。

Reconstruction

给定量化后的令牌索引 \(q\),通过查找码本 \(Z\) 得到近似特征图 \(\hat{f}\)

$$\hat{f} = \text{lookup}(Z, q)$$

然后使用解码器 \(\mathcal{D}(\cdot)\) 重建图像 \(\hat{im}\)

$$\hat{im} = \mathcal{D}(\hat{f})$$

训练量化自编码器的总损失函数为:

$$\mathcal{L} = \| im - \hat{im} \|_2 + \| f - \hat{f} \|_2 + \lambda_{\mathrm{P}} \mathcal{L}_{\mathrm{P}}(\hat{im}) + \lambda_{\mathrm{G}} \mathcal{L}_{\mathrm{G}}(\hat{im})$$

其中:

讨论

在 NLP 领域,Next-token Prediction 的逻辑是很顺畅的。

人类输出的句子本身就具有因果性,词与词之间的有严格的逻辑上的因果顺序,因此模型像人一样,基于已经说出的内容,继续输出后续的内容,就很合理了。

但是在视觉领域,这中间就出现了些许别扭:图像的 tokens 之间是相互依赖的,一个 token 中的信息既依赖于其前的 token, 也依赖于其后的 token. token 与 token 之间并没有逻辑上的先后顺序,而是一种空间上的依赖。

我们将 2d 的图像 token 通过 rasterization 转成了 1d 的形式,这并没有把空间依赖转变成因果依赖。这是本文想要解决的一个核心问题。

同时,图片的中央,信息量一般比较大;图像的四周,信息量一般比较小。图像的 patch 与 patch 之间具有空间依赖。

因此,从左到右、从上到下,把图片的每一个部分拆成同样信息量的 token, 这也是不合理的,没有实现信息量的最优分配。

我们需要想到一个办法来利用上这种信息的不均匀性以及图像信息的全局依赖性。

传统方法的四大弊端

  1. 数学前提违背:VQVAE编码器产生的特征图本身是双向依赖的,将其展平为序列后,令牌间依然存在双向关联,这与自回归模型“当前令牌只依赖于前缀”的单向依赖假设相矛盾。
  2. 零样本泛化能力受限:由于单向建模,模型无法完成需要双向推理的任务,例如给定图像下半部分预测上半部分。
  3. 结构退化:将二维空间网格展平为一维序列,破坏了图像固有的空间局部性和邻近令牌间的强相关性。
  4. 效率低下:生成一张 n x n 令牌的图像需要 O(n²) 步自回归迭代,总计算复杂度高达 O(n⁶),导致推理速度极慢。

基于下一尺度预测的 AR 模型

VAR 模型提出了下一尺度预测的算法,有效地解决了传统的 Next Token Prediction 在视觉领域水土不服的问题。

方法描述

接上文,我们已经说到了 rasterization 以后得到的 token 序列之间并没有形成因果关系。要想解决这种“别扭”,我们就需要重新建模我们的方法,让我们的 AR 过程从利用空间依赖转变为利用图像中的因果关系。

那么图像中的因果关系在哪里呢?

首先我们思考:当我们看到一张图像的时候,我们是如何对其形成认知的?

是的,在视觉领域,我们认知一张图片,都是从图片的整体到局部、从语义到细节。

本文提出的想法就是,将图像的 next-token prediction 替换成 next-scale prediction. 也就是说,从图像的低分辨率开始,向高分辨率漂移,向模糊的图像中逐步填充细节,完成图像的生成。

该过程如下:

给定一个特征图 \(f \in \mathbb{R}^{h \times w \times C}\),将其量化为 \(K\) 个多尺度令牌图 \((r_1, r_2, \dots, r_K)\),每个图的分辨率 \(h_k \times w_k\) 逐级提高,最终 \(r_K\) 与原始特征图分辨率 \(h \times w\) 匹配。视觉自回归建模(VAR)的似然函数定义为:

$$p\left(r_1, r_2,\ldots, r_K\right)=\prod_{k=1}^K p\left(r_k\mid r_1, r_2,\ldots, r_{k-1}\right),$$
其中每个自回归单元 \(r_k \in [V]^{h_k \times w_k}\) 是尺度 \(k\) 的令牌图,包含 \(h_k \times w_k\) 个令牌,序列 \((r_1, r_2, \dots, r_{k-1})\)\(r_k\) 的“前缀”。在第 \(k\) 个自回归步骤中,\(r_k\) 中的所有 \(h_k \times w_k\) 个令牌的分布将基于其前缀和对应的第 \(k\) 个位置嵌入图并行生成。

以下是伪代码:

这种思路有些类似于泰勒展开。先用低频特征去拟合原始特征,再用高频特征去补充细节,最终逼近真实特征。

优点

基于这种方法,本文提出的方法成功解决了前面 Next-token Predcition 所遇到的问题:

首先,Next-scale Prediction 通过设计一个多尺度量化自编码器,确保每个尺度的令牌图 \(r_k\) 的生成仅依赖于其前缀 \((r_1, r_2, \dots, r_{k-1})\)。这符合自回归模型“当前单元只依赖于前缀”的单向依赖假设,解决了传统方法中因VQVAE编码器产生双向依赖特征图而导致的数学前提违背问题。

其次,Next-scale Prediction 也保留了图像的空间局部性。它没有使用展平操作,VAR直接处理二维令牌图。同时,尺度内完全相关,每个尺度 \(r_k\) 内的所有令牌是并行生成且完全相关的。并且,Next-scale Prediction 实现了多尺度设计强化结构:多尺度(从粗到细)的设计本身进一步强化了图像的空间结构。

第三,Next-scale Prediction 也有利于提升生成效率。传统AR生成 \(n \times n\) 图像需要 \(O(n^2)\) 步迭代,总计算复杂度为 \(O(n^6)\)。VAR生成同样图像仅需 \(O(\log(n))\) 步迭代(即尺度数量),总计算复杂度降低至 \(O(n^4)\)。效率提升源于每个尺度内的所有令牌是并行预测的。

这种方法也同样有利于实现零样本任务泛化,VAR模型在未经微调的情况下,能够完成需要双向推理的任务,如图像修复(in-painting)、外绘(out-painting)和编辑。这是因为在推理时,可以灵活地“教师强制”(teacher-force)已知区域的令牌(作为条件),并让模型自回归地生成未知区域的令牌,从而实现了传统单向AR模型无法完成的零样本泛化。

Next-scale Prediction 与渐近式压缩编解码

事实上, Next-Scale Prediction 已经实现了图像的渐近式压缩编解码。

对于一张图像而言,使用 Next-Scale Prediction 的方法,在编码端对图像的不同尺度进行了信息提取以后,可以立即将其结果 tokenization 之后传输到解码端。解码端进行解码以后,立即可以完成对图像的进一步重建。

但是, VAR 依然局限于图像本身的空间结构,每一个 token 本身局限于空间结构,无法真正地提取图像的全局信息。

基于下一描述 token 的 AR 模型

我们的方法基于 1D-Tokenizer ,打破了图像空间结构的束缚,可以让模型充分提取图像中的信息。

对于编码端,我们有两种选择:

  1. 始终保留上一轮输出的结果 latent tokens (这些 latent tokens 中已经包含前面轮次提取出的低频特征,新增 tokens 专门提取高频特征)。
  2. 每一轮要预测的结果为下采样图像减去之前所有轮的下采样图像的上采样结果(预测图像残差特征,也即新增特征)。

其中,第一种方法可能更适合。在 VAR 里面, VectorQuantizer 让每一个 token 匹配残差信息,但是残差减去的也是拟合特征。

对于这种模型,我们可以这样进行类比理解:

在 NLP 领域,我们会给定一段话,让模型对已有的话进行补全;而在图像中,我们也可以让模型以 token 的形式,对图像的内容进行补全。

方法描述

Encoding

给定一张图像,我们将其缩放成 \(K\) 个多尺度的图像 \((im_1, im_2,\cdots, im_K)\). 其中 \(im_K=im, \text{shape}(im_k)= (h_k, w_k)\).

每一步新增 latent tokens 为 \(l_k\)

\begin{array}{l} \text{Encoder } \mathcal E,\text{Quantizer } \mathcal Q;\\ \text{Inputs: token map } L = [], \text{image } im;\\ \text{for } k = 1, ..., K \text{ do}\\ \quad l_k = \mathrm 0_{\text{len}_k};\\ \quad im_k=\text{interpolate}(im, h_k, w_k);\\ \quad l_k = \mathcal Q(\mathcal E(im_k, *L, l_k));\\ \quad L = \text{queue_push}(L, l_k);\\ \text{Return: latent token sequence } L\\ \end{array}

Reconstruction

在图像重建阶段,每一轮都会得到新的 latent tokens 为 \(l_k\)

\begin{array}{l} \text{Decoder } \mathcal D; \\ \text{Inputs: token map } L;\\ \hat {im}=0;\\ \text{for } k = 1, ..., K \text{ do}\\ \quad l_k = \text{Queue_pop}(L) ;\\ \quad \hat {im} = \mathcal D(\hat {im}, l_{0:k});\\ \text{Return: Reconstructed image } \hat {im}\\ \end{array}