0%

解读何恺明团队工作:分形生成其实是一种多叉树视觉 Transformer

最近,备受瞩目的何恺明团队公布了一篇论文——分形生成模型(Fractal Generative Models)。该论文提出了一种叫做分形生成的全新生成范式。以图像分形生成为例,算法会由粗到精地生成每个像素。

但我仔细读过一遍论文后,发现论文的表述不够准确。这篇文章其实提出了一种基于多叉树的效率更高的视觉 Transformer,以降低普通 Transformer 全像素自注意力的高计算复杂度。这种 Transformer 可以用于任何图像生成任务,甚至是图像生成以外的视觉任务。在这篇博文中,我会按我自己的逻辑介绍论文的核心方法,并简单展示论文中的实验结果。之后,我会批判性分析论文的表述,并探讨一些后续可能的科研方向。

知识准备

视觉 Transformer

Transformer 是一种处理序列数据的神经网络。它的核心是自注意力运算。在这个运算中,序列中的元素会两两交换信息。因此,如果序列的长度是 $N$,则自注意力运算的复杂度是 $O(N^2)$。

具体来说,Transformer 的输入和输出的数据是一个形状为 $N \times C$ 的张量。其中,$N$ 为序列的长度,$C$ 为每个数据向量的长度。实际上,自注意力的运算不仅和 $N$ 有关,也和 $C$ 有关。但由于 $C$ 一般是常数,而复杂度分析一般只关注会不断增长的量,所以我们记自注意力运算关于 $N$ 的复杂度是 $O(N^2)$。

图像可以认为是由像素构成的序列。因此,我们可以用 Transformer 处理图像数据,即使用视觉 Transformer (Vision Transformer)。然而,视觉 Transformer 的一大缺陷是计算复杂度过高。假如图像的边长是 $O(N)$,则一次自注意力的计算复杂度高达 $O(N^4)$。

自回归图像生成

自回归(Autoregressive)是一种直观易懂的序列生成范式:在生成第 $i$ 个元素时,生成模型输入前 $i - 1$ 个元素,输出第 $i$ 个元素。以下是文本自回归生成的一个示例:
|步数 | 输入 | 输出 |
|—-|—-|—-|
1 | (空) | 今
2 | 今 | 天
3 | 今天 | 早
4 | 今天早 | 上

如前所述,图像也可以看成一种由像素构成的序列数据。但在自回归生成图像时,我们要给每个像素编号,表示像素生成的先后顺序。只有定义了先后顺序,才能根据前面的像素生成后面的像素。

给图像编号的方式很多。最直接的想法自然是从左到右、从上到下地编号。其他的编号方式也是可行的。以下是 VQGAN 论文 (Taming Transformers for High-Resolution Image Synthesis) 展示的几种像素编号方案。

何恺明团队之前的论文 MAR (Autoregressive Image Generation without Vector Quantization) 表明,可以完全随机地给像素编号。

提升视觉 Transformer 效率

视觉 Transformer 处理图像的效率本身就偏低,再算上多步图像生成带来的计算量,生成一张图像的速度将慢得惨不忍睹。能否加速视觉 Transformer 的计算效率呢?

在早期的 ViT 工作 (An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale) 中,图像在正式输入 Transformer 前会做一步叫做「图块化」(patchify)的预处理操作。原来 $256 \times 256$ 大小的图像会被下采样 16 倍,转换成 $16 \times 16$ 个图块。输入的元素数变少了,Transformer 的计算时间自然也就降下来了。

后续工作延续了这种压缩输入元素数的思想,但采用了不同的图像压缩方式。VQVAE, VQGAN, Latent Diffusion Model (LDM) 等论文使用自编码器对图像做近乎无损的压缩,再仅用生成模型来生成压缩图像。后续的 DiT (Diffusion Transfomrer, 来自论文 Scalable Diffusion Models with Transformers) 把这种借助自编码器的压缩方案集成到了基于 Transformer 的扩散模型中。

树是一种常见的数据结构,它以从整体到局部的顺序描述某种事物。树可以表示抽象的生活概念或者具体的计算机概念。比如,通常书籍的结构都是树形,我们可以用第一章、第 1 小节、第 1.1 小节这种逐级递进的结构组织书的内容。

和植物里的树不同,数据结构中的树是从上往下生长的。最上面的节点叫做根节点。每个节点相邻的下层节点叫做子节点,相邻的上层节点叫做父节点。没有子节点的节点叫做叶节点。树本身还是一种满足递归性质的数据结构:我们可以忽略某节点的父节点及父节点连接的所有其他节点,从而将其看作是一个新的树。也就是说,每个节点其实都代表一个以这个节点为根的树。这样的树被称为子树

我们也可以用树来组织一个一维数组中的信息。比如,我们把四个数存在树的叶节点里,并让其余每个节点都维护整个子树里所有数据的某种统计信息(这里我们把统计信息定义为求和)。

用这种树表示数据有什么好处呢?由于节点都维护了整个子树的所有信息,在查询整个树里某种统计信息时,我们不必访问每个叶节点,而是可以直接从某个中间节点中直接返回信息,以提升计算效率。比如,在上面的求和树中,我们要查询第 1-3 个元素的总和。我们不必逐个访问三个叶节点,而是可以去访问更少的节点。严格来说,普通查询区间和的复杂度是 $O(N)$,用这种求和树的话查询的复杂度会降为 $O(logN)$

多叉树 Transformer

Transformer 的平方复杂度

我们先以一维数据为例,学习这篇论文是怎么加速 Transformer 的。在正式学习方法之前,我们再次复习一遍为什么 Tranformer 的运算是平方复杂度的。

忽略 Transformer 中注意力运算的实现细节,我们仅关心每个运算的输入输出是什么。在 Transformer 中,由于自注意力运算是一种全局运算,每个元素的输出都取决于其他所有元素。从信息的交换次数来看,如果序列有 $N$ 个元素,每个元素都要至少访问 $N$ 个输入元素的信息,总计算的复杂度是 $O(N \times N)= O(N^2)$。

用多叉树维护区间信息

在前文对树的知识回顾中,我们知道树的某个根节点可以维护整个子树内所有数据的统计信息。那么,我们可不可以用树来减少 Transformer 的信息交换次数呢?比如,我们要让 8 号元素查询 1-4 号元素的信息,我们不去逐个查询叶子节点,而是去查询它们构成的子树的根。

那么,怎么让 Transformer 具有这样的树形结构呢?首先,我们先想办法让某个节点表示一个区间内的所有数据。回忆一下,在 Transformer 中,数据的输入输出形状都是 $N \times C$。这里,我们先假设 $C=1$,即每个数据都是一个实数。这样,我们可以通过通道拼接的方式把多个数据放在一个节点里,并用形状修改操作实现树节点的「分裂」操作。

如下图所示。一开始,根结点处只有一个数据,数据的长度为 $8$,表示八个实数的拼接。树的每一条边表示 reshape 操作,表示把数据从通道维度 $C$ 上拆开,放到序列长度维度 $N$ 上。最后,数据变成了原本的形状 $8 \times 1$。

接着,我们把多个 Transformer 加入树中。我们只允许同一个节点的所有子节点交换信息,而不与同一级的其他子节点交换信息。

要进行多个独立的计算,可以利用数据的 batch 维度:我们把数据的形状从 $N \times C$ 拓展成 $B \times N \times C$,其中 $B$ 表示 batch 的数量。每个 batch 之间的运算是完全独立的。在下图中,一个蓝框表示一个 batch 的运算,不同框内运算是独立的。每次运算中,序列长度 $N$ 始终等于 $2$,这个 $2$ 表示的是一个树节点的子节点数。$C$ 和之前一样,表示数据向量的长度乘以区间里数据的个数。由于我们默认数据向量长度为 $1$,所以这里的 $C$ 还是表示区间里数据的个数。

一个树的节点如果最多有 $k$ 个子节点,则该树被称为 $k$ 叉树。我们这里默认使用的是二叉树。

这样的结构乍看起来很奇怪:这个树的每个节点表示什么意思?为什么只允许同一个节点的子节点之间交换信息?这样真的能加速吗?我们来一一解释这些问题。

多叉树 Tranformer 的原理解释

首先,我们先忽略数据间的信息交换方式,仅考查每个节点的意义。由于更底层(更深层)的节点经过了更多次 Transformer 运算,所以深层的节点拥有的信息更加准确;与之相对,更浅层的节点的信息更加模糊,但概括性更强,能够描述多个节点的统计信息。如下图所示,a, b, c 三个节点的概括性越来越弱,但信息越来越准确。

这个树和我们之前见过的求和树一样,每一个节点都维护了整个子树的统计信息。然而,对于求和操作,我们可以精确地维护一个子树里所有数据的和;但对于结果无法直接统计的 Transformer 运算,节点的信息越往上越模糊。我们在之后的分析中,必须要考虑浅层节点的信息损失。

然后,我们来考查某一个节点能够看到哪些信息。不妨来考查 8 号节点。如下图所示,我们可以找出 8 号节点在整趟运算中「看过」的节点。它看过了总结 1-4 号节点的节点 a,看过了总结 5-6 号节点的节点 b,还看到了它自己以及相邻的 7 号节点。这样看来,8 号节点确实看到了序列里的每一项数据的统计信息。

然而,如前所述,浅层节点的概括性虽强,它的准确性也越低。因此,每个节点在访问其他节点时,对于越邻近的节点,获取的信息越准确;相距越远的节点,获取的信息越模糊。这种设计其实假设了数据具有局部性

  • 数据的输出几乎只取决于局部信息,越远的数据影响越小。

由于图像数据满足局部性,因此用这种结构处理图像时,浅层节点的信息损失是能够接受的。

最后,我们来计算这套模型的计算复杂度,以验证这种模型能够加速 Transformer。我们仅考虑一个元素的运算。为了获取一个元素的输出,我们要做 $K$ 次 Transformer 运算,其中 $K$ 为树的高度。每一轮 Transformer 运算中,数据的序列长度、通道数都是常数。因此,一轮运算的复杂度是 $O(K)$。如何计算树的高度 K 呢?假设我们用的是二叉树,则每加一层,树能表示的数据就乘 2。因此 $K=log_2N$。最终,一轮运算的复杂度是 $O(logN)$。

当我们要处理 $N$ 项数据时,只需要对复杂度乘一个 $N$。最终,这个新网络结构处理整个序列的复杂度为 $O(NlogN)$,这比普通 Transformer 的 $O(N^2)$ 快多了。

小结

为了减少 Transformer 中的信息交换,我们重新定义元素的信息交换方式:元素不会直接看到其他元素的准确信息,而是会看到其他元素的统计信息。我们要求越远的信息概括性越强,但准确度越低。这样,每个元素都只需要访问 $O(logN)$ 个信息节点。

为了产生这种不同层级的信息节点,我们使用了一种树状 Transformer 结构。越浅的节点经过的 Transformer 块越少,信息越模糊。为了在不增加复杂度的前提下进行信息交换,我们只允许子节点之间进行 Transformer 信息交换。由于子节点数是常数,每轮 Transformer 计算的计算量是固定的。

我们完全可以让树的子节点更多,以提升效率。比如,我们可以把二叉树拓展成四叉树。

图像自回归生成实例

这套多叉树 Transformer 可以被广泛地用到各种视觉任务上。我们以论文中的图像自回归生成为例,简单了解一下这种模型的拓展方式。

从一维数据到二维数据

把多叉树 Transformer 用到图像上时,我们只需要修改元素的编号方式,它从逻辑上等价于一维数据。比如,我们可以用如下的四叉树划分二维空间。

应用到自回归图像生成

相比于一般直接预测结果的图像任务,自回归生成有一个额外的要求:后生成的像素无法看到先生成的像素,需要为 Transformer 的自注意力生成一个描述先后顺序的掩码。因此,在把多叉树应用到自回归图像生成时,我们只需要决定每一处 Transformer 的掩码是什么。

决定掩码,其实就是决定元素的先后顺序。现在,Transformer 的运算仅在局部进行,我们其实只要随意为同一组做 Transformer 运算的数据标号。比如,对于二维四叉树,我们可以用下面的顺序对各组元素标号。

从整体上看,我们并没有修改自回归生成的定义。上述过程只是一种自回归生成的特例而已,它等价于某种全局标号的普通自回归生成。

重要实现细节

在附录中,作者提供了两项提升图像自回归生成效果的实现细节。

指引像素 在生成高分辨率图像时,让浅层的 Transformer 输出的颜色值与后续深层的输出颜色值对齐。当然,由于浅层的图块更大,这实际上是让深层输出颜色值的平均值和浅层输出对齐。

临近图块生成 直接用这套 Transformer 生成图像,会导致图块与图块的边缘不一致。为此,作者修改了 Transformer 的输出,让它除了输出当前图块的结果外,还输出上下左右四个相邻图块的结果。

作者在代码中用了一些相对复杂的逻辑,只读论文难以理解方法实现细节。对细节感兴趣的读者欢迎阅读开源代码库里的 models/mar.py 文件。

实验结果

我们简单看一下本文的图像生成结果。由于这种 Transformer 复杂度较低,作者实现了一个像素级生成模型,而没有按照流行的方法使用两阶段潜空间 (latent space) 图像生成。由于模型直接输出的是每个像素取某一颜色值的概率(类似于 PixelCNN),该模型能够准确建模图像的概率。这种概率可以用 NLL (Negative Log-Likelihood) 指标反映,越低越好。

先看最重要的 ImageNet-256 图像生成任务。作者仅比较了其他像素级生成模型。本文的最好的生成模型 FractalMAR 的 FID 并不是很好。现在主流生成模型的在此任务上的 FID 都小于 2。

再看一下 NLL 指标的结果,作者比较了那些能够准确输出图像概率的模型。从最大似然估计的角度看,本文的方法确实很不错。

从支持的任务上看,本文和其他类别约束的自回归模型一样,支持图像内插/外插,且可以用 ImageNet 类别作为指引。

论文表述批判性分析

在初次读文章时,我无论是看示意图还是看公式、文字,都不能理解算法的意思。反复读了几遍后,我才大概明白作者提出的其实是一种多叉树结构。我最后通过阅读代码验证了我的理解是正确的。我认为论文的部分叙述不够严谨,且对于贡献的总结不够准确。具体的分析如下。

是否是分治算法

对于一些朴素算法为 $O(N^2)$ 的序列处理任务,分治 (divide-and- conquer) 算法通过把任务拆成子任务并递归求解子任务的方法,将总算法的复杂度降到 $O(N^2)$ 以下。比如我们熟悉的快速排序就是一种分治算法。论文中反复将分形自回归称为分治算法,但它与传统意义上的分治算法存在较大的差别。

分治算法的核心是「治」这一步。在子问题解决完了之后,我们要用低于 $O(N^2)$ 的时间合并两个子问题的解,产生一个新的当前问题的解。而对于本文介绍的图像生成任务,或者说本文介绍的多叉树 Transformer 模型,模型的输出是一次性决定好的,不存在返回父节点再次修改这一步,自然也没有什么低于 $O(N^2)$ 的合并算法。

如果用分治算法实现图像生成,那么我们应该在不依赖全局信息的前提下生成局部像素,然后在父节点里根据邻近像素的生成结果,修改当前节点里所有像素的值。本文先决定整体再决定局部的方法恰好是反过来的。

与其称为分治算法,本文的算法更加靠近树形动态规划:一个区间内所有元素的统计信息可以被综合考虑,而不必逐个访问每个元素。

分形自回归是否是一种高效的新生成范式

作者声称以自回归模型为分形生成器的分形生成模型相较以往自回归建模方式计算效率更高。但是,在我看来,提升效率的本质原因是 Transformer 模型设计,而非生成范式上的改进。同时,作者的理论分析和方法设计完全对不上,实际上使用的还是传统自回归生成范式。

原文的有关说明如下:假设每个自回归模型的序列长度是一个可控常量 $k$,并让随机变量的总长度 $N=k^n$,其中 $n=\text{log}_k(N)$ 表示我们的分形框架的递归层数。第一个分形框架的自回归层随后将(自回归的)联合概率分布划分成 $k$ 个子集,每个子集包含 $k^{n-1}$ 个变量。

这里 $k$ 可以理解成多叉树的最大子节点数,即 $k$ 叉树。$n$ 表示树的高度。包含 $k^{n-1}$ 个变量表示一个大图块里包含 $k^{n-1}$ 个像素。

正式地,我们将概率分布分解成

这个式子表示,总的概率分布可以拆成若干个条件概率。每个条件概率的已知事件是之前所有大图块里所有像素的概率,条件事件是当前大图块所有像素的概率。

每个含 $k^{n-1}$ 个变量的条件分布 $p(…|…)$ 接着被第二级递归的自回归模型建模。以此类推。通过递归地调用这种分治过程,我们的分形框架可以通过用 $n$ 级自回归模型来高效率地应对 $k^n$ 个变量的联合分布求解问题,每一级只需要处理长度为 $k$ 的序列,且该序列长度可控。

以上是论文原文。我开始读论文的时候没读懂这段话的意思。在我写这篇博文的时候,发现作者的意思其实可以用前文的示意图表示。由于每一个递归级的每个子问题都只需要处理长度为 $k$ 的序列,且这个 $k$ 是常量,所以总的复杂度是 $O(NlogN)$。

但是,作者这段话不是针对 Transformer 来讲的,而是针对概率分布来讲的。用严谨的话描述,作者的意思可能是说,上面的那个概率分布按递归展开,最后只会有 $O(NlogN)$ 项。因此,总的计算复杂度降低了。

计算概率分布的算法大致可以被称为分治算法。但计算概率分布和执行多叉树神经网络是两码事。

这一段推理存在逻辑漏洞:概率公式有 $O(NlogN)$ 项,并不能代表最后的计算效率是 $O(NlogN)$。这中间还欠缺一个前提条件:每一项条件概率的计算时间是常数级。实际上,一轮自回归生成的计算时间取决于神经网络的设计,而跟自回归的建模方式无关。比如,同样是计算经典自回归条件概率 $P(x_i|x_1, …, x_{i-1})$,如果用 Transformer,它的复杂度就是 $O(i^2)$;如果用 CNN,它的复杂度就是 $O(i)$。

另外,上述条件概率的计算和论文方法设计完全无关。论文完全没有提如何用模型估计各个叶节点的条件概率,然后用分治算法合并概率,用一套新损失函数优化联合概率。论文的方法完全是按照经典自回归的定义,用神经网络建模 $P(x_i|x_1, …, x_{i-1})$,再用交叉熵损失优化最新像素的类别分布。只不过这个神经网络是一个使用了多叉树优化的 Transformer。

综上,我认为作者对于方法的分析和论文贡献的描述有误。作者并没有用到论文中提出的递归概率分布,只是用到了一个多叉树 Transformer。论文的核心贡献是加速 Transformer,而不是一种全新的生成范式。

科研方向探讨

正如论文所展示的,本文提出的多叉树 Transformer 可以用到 AR 和 MAR 两种生成范式上。我们也可以考虑把它拓展到其他生成范式,甚至其他视觉任务上。比如,将其拓展到 VAR (Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction) 上就是一个显而易见的改进方向。 VAR 用同一个 Transformer 处理所有尺度的图像,但这显然不是最优的。结合多叉树 Transformer,我们或许能够让不同层的 Transformer 输出不同尺度下的预测;同时,我们也可以用这种结构提升 VAR 的性能,避免使用全序列注意力。

退一步,从更宏观的角度上看,这篇论文以及 VAR 等论文都是通过利用图像特性来减少计算量。我认为最重要的两个特性有:

  • 局部性:像素受到邻近像素的影响更大,受到远处像素影响更小。
  • 局部连续性:一块像素的颜色、信息是类似的,且变化是平滑的。这使得我们可以用下采样/图块化的大像素块来近乎无损地表示一块像素的统计信息,也允许我们将低尺度信息的线性上采样结果作为当前结果的近似。

基于这两种特性,多数工作采用了如下的优化方案:

  • 降低像素间的依赖关系,较远处的像素与此处的像素可以相互独立地并行生成。
  • 用某种下采样表示一整块像素的信息。
  • 将图像生成建模成从低尺度到高尺度的递进式生成。
  • 使用残差设计,将高尺度图像定义为低尺度图像上采样加上残差图像。

问题的关键就在于,我们应该把哪种优化方案放到算法的哪一步中。这篇论文实际上是把图像的多尺度表示用在了 Transformer 模型上。

实不相瞒,我前段时间也尝试设计了一种使用二维四叉树的自注意力,用于加速所有视觉 Transformer,同时也希望实现一个像素级 Transformer。然而,我发现当下采样的比例超过 2 之后,模型的效果就会出现明显下降。当然,我相信对图像做全注意力是没有必要的,肯定存在着更优的加速方案,这个方向仍有研究价值。

总结

Fractal Generative Models 论文提出了一种用多叉树优化的视觉 Transformer 结构。该 Transformer 将图像表达成多个尺度,浅层信息模糊但概括性强,深层信息准确但概括性弱。每个元素最终仅与 $O(logN)$ 个元素做自注意力,距离越远的元素,使用的信息越浅,访问的节点越少。作者用这个 Transformer 实现了自回归生成任务。特别地,作者实现了一个能够准确计算图像概率的像素级自回归模型。该模型在 ImageNet-64 NLL 指标上超越了以往模型,但在 ImageNet-256 的 FID 指标上不尽如人意。

作者将这套方法称为一种新的生成范式,但我认为作者的表述有误。这种效率更高的视觉 Transformer 可以用在任何视觉任务上,比如 VAR 的生成范式上。我们也可以沿着这篇论文的设计思路,继续思考如何利用图像的自身特性,在不显著降低效果的前提下提升 Transformer 的计算效率。

相关解读文章

解读何恺明团队新作:不用向量离散化的自回归图像生成

NIPS 2024 最佳论文 VAR 深度解读:下一尺度预测为何能超越扩散模型?

欢迎关注我的其它发布渠道