从定量指标上看,Sana 确实好过不少此前的文生图模型。但是,这些指标无法如实反映人类的审美偏好。如前文所述,社区用户认为 Sana 并没有明显好于 SDXL,但指标无法反映这一点。这些指标参考价值不大的另一个证据是 FLUX-dev 和 FLUX-schnell 的比较结果——作为一个被进一步蒸馏出来的模型,schnell 显然是比 dev 的生成质量更差的,但它的 FID, GenEval, DPG 竟然比 dev 还好。因此,在比较文生图模型质量时,我个人建议不要参考文生图的定量指标,而是去参考社区用户的反馈。
另外,虽然 Sana-1.6B 比 FLUX-dev 快了很多,但它比 FLUX-schnell 只快了一倍。或许 Sana 也可以通过蒸馏获得进一步的推理加速。
总结
Sana 是一个以降低运算开销为主要目标的高分辨率文生图模型。它主要通过增加 VAE 压缩比例、使用线性注意力来提升 DiT 的效率。为了弥补线性注意力的能力损失,Sana 将 FFN 改成了 3x3 卷积网络,在提升模型能力的同时免去了位置编码。除了这两大主要设计外,Sana 还通过使用轻量级文本编码器等诸多细节改进提升了模型的训练与推理效率。整体上看,这个工作主要在工程上作出了贡献,线性注意力的设计几乎照搬了之前的工作,没有使用比较新颖的模块设计。
从生成效果上看,尽管 Sana 论文给出的定量指标还不错,但这些指标是否能如实反映文生图质量还存疑。据社区用户反映,Sana 的质量没有明显好于 SDXL。另外,虽然论文一开头就宣称 Sana 能够生成 4096x4096 的图片,但这种图片的细节很差,和 1024x1024 的差不多。这是因为不管是 VAE 还是 DiT 都只在 1024x1024 上训练过。在加大生成尺寸后,图像的清晰程度没有变,只是看起来像素变多了。这篇论文真正应该强调的是生成 4K 图像的速度会更快,而不应该去强调 4K 图像的质量有多高。
从生成速度上来看,Sana 确实比最强开源文生图模型 Flux-Dev 要快很多。但尴尬的是,在 1024x1024 图像生成上,Sana 的速度仅仅是精心蒸馏的 Flux-schnell 的两倍。当然这个对比可能不是很公平,因为 Sana 还没有经过蒸馏。但就目前来看社区用户在生成 1024x1024 的图像时难以感受到 Sana 性能优势。
这篇文章很好地指明了 DiT 的一个发展方向:我们能不能将线性注意力更好地引入 DiT 呢?按我目前的理解,线性注意力就是通过牺牲自注意力全局依赖关系来提升模型计算速度。它的本质和 MAR (Autoregressive Image Generation without Vector Quantization)、VAR (Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction) 很像,都是通过减少计算像素间的依赖关系来提升速度。这个假设在视觉任务中是很合理的:在确定了图像的主要结构后,理解细节只需要局部像素信息。然而,这些加速方法都不可避免地降低模型的能力。在完全不用和完全使用全局信息之间,我们或许要找到一个平衡点,来使 DiT 具有最佳性能和效果。
最近一篇论文因其吸引眼球的标题而刷屏科技自媒体:”The GAN is dead; long live the GAN! A Modern Baseline GAN (GAN 已死?GAN 万岁!一个现代 GAN 基模)”。我不喜欢这种浮夸的标题,因为真正有价值的论文不必靠标题吸引人。带着怨气通读完论文后,我发现这篇论文果然没有做出特别大的创新。
Pycco 这种直观暴力的实现方法让网页开发者能够快速地修改页面生成逻辑。然而,我已经把 HTML 的知识快忘光了,配不上「网页开发者」这个名号。因此,我让 ChatGPT o1 来帮我开发这一功能。
经指导,我认识了 MathJax 这个在网页上渲染公式的工具。只需要在 HTML 的 head 里导入一些包,网页就可以自动识别单行公式 和多行公式 $$$ $$$。我不记得 head 是什么了,但大概能猜到这个是一个相当于声明全局变量的语句块。
我在 pycco_resources\__init__.py 文件里找到了设置 head 的地方。这个文件提供了生成网页的模板,包括了写死的 CSS 文件和部分 HTML 代码。打开这个文件的最快方式是在某处写 import pycco_resources,之后用 IDE 的代码跳转找到这个包在本地的位置。
判断了是否要高亮后,我还需要做对应的修改。我不仅要在 HTML 代码块里高亮代码,还需要把注释块里的特殊命令删掉。通过观察相关代码,我忽然回忆起了 HTML 的部分实现原理:背景高亮就是把一段 HTML 字符串封进一个带有背景高亮样式的标签 <div></div> 里。那剩下的删除注释也很简单,只需要对字符串做一点简单操作就行了。代码修改过程及使用示例如下所示。
在这篇博文中,我会先回顾与 VAR 密切相关的早期工作 VQVAE 和 VQGAN,再介绍论文的方法细节与实验结果,最后分享我对该工作的测试结果与原理探究。在读 VAR 论文时,我发现有个地方的设计存在缺陷。相关实验结果表明, VAR 论文并没有完整地分析出这套方法有效的原因。欢迎大家仔细阅读这一部分并提出自己的思考与见解。
VAR 用了一种更加高级的做法:用残差金字塔来表示这些隐空间特征。我们先来回顾一下拉普拉斯金字塔这一经典图像处理算法。我们知道,图像每次下采样的时候,都会损失一些信息。既然如此,我们可以将一张高分辨率的图像表示为一张低分辨率的图像及其在各个分辨率下采样后的信息损失。如下图所示,最右侧的一列表示拉普拉斯金字塔的输出。
VAR 的方法部分我们看得差不多了,现在来简单看一下实验部分。论文宣称 VAR 在图像生成实验和参数扩增实验上都取得了不错的成果。特别地,VAR 的拟合能力胜过了 DiT,生成速度是 DiT 的 45 倍以上。我们就主要看一下 VAR 在 ImageNet $256 \times 256$ 图像生成上的实验结果。以下是论文中的表格。我同时还附上了何恺明团队的 MAR 工作(Autoregressive Image Generation without Vector Quantization)的实验结果。
先比一下 DiT 和 VAR。先看速度,不管是多大的模型,DiT 的速度都远远慢于 VAR。再看以 FID 为代表的图像拟合指标。VAR 在参数量为 600M 左右时并没有 DiT 效果好。但继续增加参数量后,DiT 的 FID 没有变好的趋势,而 VAR 的 FID 一直在降。最终 VAR 的 FID 甚至超过了 ImageNet 的验证集,可以认为 FID 再低的也意义不大了。
再比一下 MAR 和 VAR。MAR 的刷指标能力更加恐怖,943M 的模型就能有 1.55 的 FID。但根据 MAR 论文,其速度是 DiT-XL 的 5 倍左右,也就是说 VAR 还是比 MAR 快,是 MAR 速度的 9 倍左右。
ImageNet 图像生成已经被各个模型刷到头了。FID 结果能说明 VAR 的拟合能力很强,最起码不逊于 DiT。但在更有挑战性的文生图任务上,VAR 的效果还有待验证。另外,虽然刷指标的时候 DiT 用了 250 步采样,但实际用起来的时候一般就是采样 20 步。如果算上蒸馏的话,采样步数能缩小到 4 步。加上这些加速技巧的话,VAR 不见得会比 DiT 快。
VAR 各尺度生成结果
看完了论文的主要内容,我来分享一下我对 VAR 的一些理论分析与实验结果。
先看一下随机采样结果。我用的是最大的 d=30 的 VAR 模型。在官方采样脚本的默认配置下,两个随机种子 (0, 15) 的输出如下所示。用到的图像类别为火山、灯塔、老鹰、喷泉,每个类别的图各生成了两张。图像的生成速度很快,一秒就生成了全部 8 张图片。
此前,以经典工作 VQGAN 为代表的图像自回归生成模型无论在速度上还是图像质量上都不尽如人意。究其原因,下一个图像词元预测的建模方式既不够合理,也拖慢了生成速度。为此,VAR 提出一种新式自回归策略:将词元图像拆分成多个尺度,通过下一尺度预测实现图像生成。为了兼容这一设计,VAR 对 VQGAN 的自编码器和 Transformer 都进行了修改:自编码器能够将图像编码成多尺度的残差词元图像,而 Transformer 同时输出同一尺度每个词元的独立分布。实验表明,VAR 在 ImageNet 图像生成指标上超越了以 DiT 为代表的扩散模型,且生成速度至少比 DiT 快 45 倍。另外,还有实验表明 VAR 符合扩增定律:增加参数量即可提升模型性能。
我个人认为,和其他前沿生成模型一样,VAR 在 ImageNet 上的表现已经满分了。它能否完成更困难的图像生成认为还有待验证。最近字节跳动发布了 VAR 的文生图版本:Infinity,但这个模型还没有开源。我们可以持续关注 VAR 的各个后续工作。VAR 的生成速度也没有比 DiT 快上那么多,通过减小采样步数,再加上模型蒸馏,DiT 不会比 VAR 慢。当然,VAR 或许也存在进一步加速的可能,只是相关研究暂时没有扩散模型那么多。
VAR 的数学模型是存在缺陷的:词元图的分布不应该等于词元间的独立分布的乘积。最起码论文里没有任何相关分析(用了类似做法的 MAR 论文也没有分析)。通过一些简单的生成实验,我们发现由于 VAR 在其他设计上提升了输出图像的连续性,哪怕同一尺度的词元间是独立采样,甚至是随机均匀采样,模型的输出质量也不会太差。我们需要通过更深入的实验来挖掘 VAR 的生效原理。
我觉得如果一个科研工作能够解释清楚 VAR 中哪些模块起到了最主要的作用,并取其精华,去其糟粕,提出一个更好的生成模型,那这会是一个很不错的工作。我觉得能够探索的方向有:
VAR 的前几个尺度的词元图是最重要的。能不能用更好的方式,比如用扩散模型,来生成前几个尺度的图像,而更大尺度的词元图用一个比 Transformer 更高效的模型来生成。这样模型的质量和效率能进一步提升。
VAR 还是用了 VQ 自编码器。无论怎么样,VQ 操作都会降低模型的重建质量。但另一方面,VQ 也能起到规范解码器输入的作用。究竟我们能不能把 VQ 自编码器换成精度更高的 VAE 呢?换了之后怎么设计多尺度编码呢?
将图像生成拆解成从低分辨率到高分辨率是一种很常见的思想。基于扩散模型,有多种方式来应用这种思想。一种比较直接的方式是显式将图像生成分解成生成最低分辨率的图像和多轮超分辨率,代表工作是 Cascaded Diffusion Models for High Fidelity Image Generation;另一种更加巧妙的方式是将图像上采样和扩散模型的去噪同时进行,代表工作是 f-DM: A Multi-stage Diffusion Model via Progressive Signal Transformation。本文的多尺度设计和 f-DM 非常相似,我会在文末详细分析二者的区别。
和这篇工作非常相关的早期工作是苹果在 2022 发表的 f-DM: A Multi-stage Diffusion Model via Progressive Signal Transformation。f-DM 将扩散模型的加噪推广到了降采样、模糊等其他退化策略上。降采样版的 f-DM 有非常多的设计和本工作很像。苹果该团队次年发表的 Matryoshka Diffusion Models 也用到了按分辨率逐次生成的设计。
将拉普拉斯金字塔融入扩散模型
拉普拉斯金字塔是一种图像表示方法,它把图像按频率成分拆成几张分辨率不同的图像,分辨率越低的图像表示频率越低的图像成分。我们直接通过下面的例子学习它的原理。假如 x 是原图,那么 x(3)=down(down(x)),x(2)=down(x)-up(x(3)), x(1)=x-up(down(x))。对 x(1), x(2), x(3) 求加权和就可以还原输入图像。
Lumina-Next,前沿扩散 Transformer (Diffusion Transformer, DIT) 模型,采用了长度外推技术:Lumina-Next : Making Lumina-T2X Stronger and Faster with Next-DiT (https://arxiv.org/abs/2406.18583)
近几年和 NTK 理论比较相关的论文叫做 Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains。这篇论文用 NTK 理论分析了 NeRF 这种以位置坐标为输入的 MLP 需要位置编码的原因,并将这类位置编码归纳为「傅里叶特征」。
在这篇文章里,我主要基于论文 Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains (后文简称为「傅里叶特征论文」),介绍傅里叶特征这一概念。为了讲清这些理论的发展脉络,我会稍微讲一下 NTK 等理论概念。介绍完傅里叶特征后,我还会讲解它在其他方法中的应用。希望读完本文后,读者能够以这篇论文为基点,建立一个有关位置编码原理的知识网络,以从更深的层次来思考新的科研方向。
这种连续数据有什么好处呢?我们知道,计算机都是以离散的形式来存储数据的。比如,我们会把图像拆成一个个像素,每个像素存在一块内存里。对于图像这种二维数据,计算机的存储空间还勉强够用。而如果想用密集的离散数据表达更复杂的数据,比如 3D 物体,计算机的容量就捉襟见肘了。但如果用一个 MLP 来表达 3D 物体的话,我们只需要存储 MLP 的参数,就能获取 3D 物体在任何位置的信息了。
这就是经典工作神经辐射场 (Neural Radiance Field, NeRF) 的设计初衷。NeRF 用一个 MLP 拟合 3D 物体的属性,其输入输出如下图所示。我们可以用 MLP 学习每个 3D 坐标的每个 2D 视角处的属性(这篇文章用的属性是颜色和密度)。根据这些信息,利用某些渲染算法,我们就能重建完整的 3D 物体。
import torch import torch.nn as nn import torch.nn.functional as F from torchvision.io import read_image, ImageReadMode from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm from einops import rearrange
classFourierFeature(nn.Module): def__init__(self, in_c, out_c, scale): super().__init__() fourier_basis = torch.randn(in_c, out_c // 2) * scale self.register_buffer('_fourier_basis', fourier_basis) defforward(self, x): N, C, H, W = x.shape x = rearrange(x, 'n c h w -> (n h w) c') x = x @ self._fourier_basis x = rearrange(x, '(n h w) c -> n c h w', h = H, w = W) x = 2 * torch.pi * x x = torch.cat([torch.sin(x), torch.cos(x)], dim=1) return x feature_length = 256 model = MLP(feature_length).to(device) fourier_feature = FourierFeature(2, feature_length, 10).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) n_loops = 400 for epoch in tqdm(range(n_loops)): x = fourier_feature(grid) output = model(x) loss = F.l1_loss(output, input_image) optimizer.zero_grad() loss.backward() optimizer.step() if epoch % 100 == 0or epoch == n_loops - 1: viz_image(output[0]) print(loss.item()) prev_output = output
根据之前的研究 Random features for large-scale kernel machines 表明,我们不需要密集地采样傅里叶特征,只需要稀疏地采样就行了。具体来说,我们可以从某个分布随机采样 $m$ 个频率 $\mathbf{b_j}$ 来,这样的学习结果和密集采样差不多。当然,根据前面的分析,我们还是令所有系数 $a_j=1$。在实验中,作者发现,$\mathbf{b_j}$ 从哪种分布里采样都无所谓,关键是 $\mathbf{b_j}$ 的采样分布的标准差,因为这个标准差决定了傅里叶特征的带宽,也决定了网络拟合高频信息的能力。实验的结果如下:
我们可以不管图片里 $1/f^x$ 是啥意思,只需要知道 a, b, c 是三组不同的实验就行。虚线是密集采样傅里叶特征的误差,它的结果反映了一个「较好」的误差值。令人惊讶的是,不管从哪种分布里采样 $\mathbf{b_j}$,最后学出来的网络误差都差不多。问题的关键在于采样分布的标准差。把标准差调得够好的话,模型的误差甚至低于密集采样的误差。
classFourierFeature(nn.Module): def__init__(self, in_c, out_c, scale): super().__init__() fourier_basis = torch.randn(in_c, out_c // 2) * scale self.register_buffer('_fourier_basis', fourier_basis) defforward(self, x): N, C, H, W = x.shape x = rearrange(x, 'n c h w -> (n h w) c') x = x @ self._fourier_basis x = rearrange(x, '(n h w) c -> n c h w', h = H, w = W) x = 2 * torch.pi * x x = torch.cat([torch.sin(x), torch.cos(x)], dim=1) return x feature_length = 256 model = MLP(feature_length).to(device) fourier_feature = FourierFeature(2, feature_length, 10).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) n_loops = 400 for epoch in tqdm(range(n_loops)): x = fourier_feature(grid) output = model(x) loss = F.l1_loss(output, input_image) optimizer.zero_grad() loss.backward() optimizer.step() if epoch % 100 == 0or epoch == n_loops - 1: viz_image(output[0]) print(loss.item()) prev_output = output
傅里叶特征通过类 FourierFeature 实现。其代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
classFourierFeature(nn.Module): def__init__(self, in_c, out_c, scale): super().__init__() fourier_basis = torch.randn(in_c, out_c // 2) * scale self.register_buffer('_fourier_basis', fourier_basis) defforward(self, x): N, C, H, W = x.shape x = rearrange(x, 'n c h w -> (n h w) c') x = x @ self._fourier_basis x = rearrange(x, '(n h w) c -> n c h w', h = H, w = W) x = 2 * torch.pi * x x = torch.cat([torch.sin(x), torch.cos(x)], dim=1) return x
构造函数里的 fourier_basis 表示随机傅里叶特征的频率,对应论文公式里的$\mathbf{b}$,scale 表示采样的标准差。初始化好了随机频率后,对于输入位置 x,只要按照公式将其投影到长度为 out_c / 2 的向量上,再对向量的每一个分量求 sin, cos 即可。按照之前的分析,我们令所有系数 $a$ 为 $1$,所以不需要对输出向量乘系数。
之后,我们来尝试在傅里叶特征中只用正弦函数。我们将投影矩阵的输出通道数从 out_c / 2 变成 out_c,再在 forward 里只用 sin 而不是同时用 sin, cos。经实验,这样改了后完全不影响重建质量,甚至由于通道数更多了,重建效果更好了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
classFourierFeature(nn.Module): def__init__(self, in_c, out_c, scale): super().__init__() fourier_basis = torch.randn(in_c, out_c) * scale self.register_buffer('_fourier_basis', fourier_basis) defforward(self, x): N, C, H, W = x.shape x = rearrange(x, 'n c h w -> (n h w) c') x = x @ self._fourier_basis x = rearrange(x, '(n h w) c -> n c h w', h = H, w = W) x = 2 * torch.pi * x x = torch.sin(x) return x
只用 sin 而不是同时用 sin, cos 后,似乎我们之前对 NTK 平移不变的推导完全失效了。但是,根据三角函数的周期性可知,只要是把输入映射到三角函数上后,网络主要是从位置间的相对关系学东西。绝对位置对网络来说没有那么重要,不同的绝对位置只是让所有三角函数差了一个相位而已。只用 sin 的神经网络似乎也对绝对位置不敏感。为了证明这一点,我把原来位于 [0, 1] 间的坐标做了一个幅度为 10 的平移。结果网络的误差几乎没变。
1 2 3 4 5 6 7
for epoch in tqdm(range(n_loops)): x = fourier_feature(grid + 10) output = model2(x) loss = F.l1_loss(output, input_image) optimizer.zero_grad() loss.backward() optimizer.step()