0%

在上一篇文章中,我们梳理了基于自编码器(AE)的图像生成模型的发展脉络,并引出了Stable Diffusion的核心思想。简单来说,Stable Diffusion是一个两阶段的图像生成模型,它先用一个AE压缩图像,再在压缩图像所在的隐空间上用DDPM生成图像。在这篇文章中,我们来精读Stable Diffusion的论文:High-Resolution Image Synthesis with Latent Diffusion Models

注意:如果你从未学习过扩散模型,Stable Diffusion并不是你应该的读的第一篇论文。请参照我[上篇文章]的早期工作总结,至少在学会了DDPM后再来学习Stable Diffusion。

摘要与引言

论文摘要的大意如下:扩散模型的生成效果很好,但是,在像素空间上训练和推理扩散模型的计算开销都很大。为了在不降低质量与易用性的前提下用较少的计算资源训练扩散模型,我们在一个预训练过的自编码器的隐空间上使用扩散模型。相较以往的工作,在这种表示下训练扩散模型首次在减少计算复杂度和维持图像细节间达到几近最优的平衡点,极大地提升了视觉保真度。通过向模型架构中引入交叉注意力层,我们把扩散模型变成了强大而灵活的带约束图像生成器,它支持常见的约束,如文字、边界框,且能够以纯卷积方式实现高分辨率的图像合成。我们的隐扩散模型(latent diffusion model, LDM) 在使用比像素扩散模型少得多的计算资源的前提下,在各项图像合成任务上取得最优成果或顶尖成果。

整理一下。论文提出了一种叫LDM的图像生成模型。论文想解决的问题是减少像素空间扩散模型的运算开销。为此,LDM借助了VQVAE「先压缩、再生成」的想法,把扩散模型用在AE的隐空间上,在几乎不降低生成质量的前提下减少了计算量。另外,LDM还支持带约束图像合成及纯卷积图像超分辨率。

在上一篇回顾LDM早期工作的文章中,我们已经理解了LDM想解决的问题及解决问题的思路。因此,在读完摘要后,我们接下来读文章时只需要关注LDM的两个创新点:

  1. LDM的AE是怎么设计以达到压缩比例与质量的平衡的。
  2. LDM怎么实现带约束的图像合成。

引言基本是摘要的扩写。首先,引言大致介绍了图像合成任务的背景,提及了扩散模型近期的突出表现。随后,引言介绍了本文想解决的主要问题:扩散模型的训练和推理太耗时了,需要在不降低效果的前提下减少扩散模型的运算量。最后,引言揭示了本工作的解决方法:使用类似VQGAN的两阶段图像生成方法。

引言的前两部分没有什么关键信息,而最后一部分介绍了本工作改进扩散模型的动机,值得一读。如下图所示,DDPM的论文展示了从不同去噪时刻的同一个噪声图像开始的不同生成结果,比如$\mathbf{x}_{750}$指从时刻$t=750$的去噪图像开始,多次以不同随机数执行DDPM的反向过程,生成的多幅图像。LDM作者认为,DDPM的这一实验表明,扩散模型的图像生成分两个阶段:先是对语义进行压缩,再是对图像的感知细节压缩。正因此,随机对早期的噪声图像去噪,生成图像的内容会更多样;而随机对后期的噪声图像去噪,生成图像只是在细节上有所不同。LDM的作者认为,扩散模型的大量计算都浪费在了生成整幅图像的细节上,不如只让扩散模型描述比较关键的语义压缩部分,而让自编码器(AE)负责感知细节压缩部分。

引言在结尾总结了本工作的贡献:

  1. 相比之前按序列处理图像的纯Transformer的方法,扩散模型能更好地处理二维数据。因此,LDM生成隐空间图像时不需要那么重的压缩比例(比如DIV2K数据集上,LDM只需要将图像下采样4倍,而之前的纯Transformer方法要下采样8倍或16倍),图像在压缩时能有更高的保真度,整套方法能更高效地生成高分辨率图像。
  2. 在大幅降低计算开销的前提下在多项图像生成任务上取得了顶尖成果。
  3. 相比于之前同时训练图像压缩模型和图像生成模型的方法,该方法分步训练两个模型,训练起来更加简单。
  4. 对于有着稠密约束的任务(如超分辨率、补全、语义生成),该方法的模型能换成一个纯卷积版本的,且能生成边长为1024的图像。
  5. 该工作设计了一种通用的约束机制,该机制基于交叉注意力,支持多模态训练。作者训练了多种带约束的模型。
  6. 作者把工作开源了,并提供了预训练模型。

我们来整理一下这些贡献。读论文时,可以忽略第6条。第2条是成果,与方法设计无关。第1、3条主要描述了提出两阶段图像生成建模方法的贡献。第4条是把方法拓展到稠密约束任务的贡献。第5条是提出了新约束机制的贡献。所以,在学习论文的方法时,我们还是主要关注摘要里就提过的那两个创新点。在读完引言后,我们可以把阅读目标再细化一下:

  1. LDM的AE是怎么设计以达到压缩比例与质量的平衡的。与纯基于Transformer的VQGAN相比,它有什么不同。
  2. LDM怎么用交叉注意力机制实现带约束的图像生成。

相关工作

作者主要从两个角度回顾了早期工作:不同架构的图像生成模型与两阶段的图像合成方法。其回顾逻辑与本系列的第一篇文章类似,在此就不过多介绍了。除了介绍早期工作外,作者重申了引言中的对比结果,强调了LDM相对于扩散模型的创新和相对于两阶段图像生成模型的创新。

方法

在方法章节中,作者先是大致介绍了使用LDM这种两阶段图像生成架构的优点,再分三部分详细介绍了论文的实现细节:图像压缩AE的实现、LDM的实现、约束的实现。开头的介绍和AE的实现相对比较重要,我们放在一起详细阅读;相对于DDPM,LDM几乎没有做任何修改,只是把要拟合的图片从真实图片换成了压缩图片,这一部分我们会快速浏览一遍;而添加约束的方法有所创新,我们会详细阅读一遍。

AE与两阶段图像生成模型

我们来先读3.1节,看一看AE的具体实现方法,再回头读第3节开头介绍的两阶段图像生成模型的优点。

LDM配套的图像压缩模型(论文中称之为”感知压缩模型”)和VQGAN几乎完全一样。该压缩模型的原型是一个AE。普通的AE会用原图像和重建图像的重建误差(L1误差或者L2误差)来训练。在普通的AE的基础上,该压缩模型参考了GAN的误差设置方法,使用感知误差代替重建误差,并添加了基于patch的对抗误差。

但该图像压缩模型的输出与VQGAN有所不同。我们先回忆一下VQGAN的原理。VQGAN的输出会接到Transformer里,Transformer的输入必须是离散的。因此,VQGAN必须要额外完成两件事:1)让连续输出变成离散输出;2)用正则化方法防止过拟合。为此,VQGAN使用了VQVAE里的向量离散化操作,该操作能同时完成这两件事。

而LDM的压缩模型的输出会接入一个扩散模型里,扩散模型的输入是连续的。因此,LDM的压缩模型只需要额外完成使用正则化方法这一件事。该压缩模型不必像VQGAN一样非得用向量离散化来完成正则化。如我们在第一篇文章中讨论的,作者在LDM的压缩模型中使用了两种正则化方法:VQ正则化与KL正则化。前者来自于VQVAE,后者来自于VAE。

该压缩模型相较VQGAN有一项明显的优势。VQGAN的Transformer只能按一维序列来处理图像(通过把二维图像reshape成一维),且只能处理较小的压缩图像($16\times16$)。而本身用于二维图像生成的LDM能更好地利用二维信息,因此可以处理更大的压缩图像($64\times 64$)。这样,LDM的压缩模型的压缩程度不必那么重,其保真度会比VQGAN高。

看完了3.1节,我们来回头看第3节开头介绍了LDM的三项优点:1)通过规避在高维图像空间上训练扩散模型,作者开发出了一个因在低维空间上采样而计算效率大幅提升的扩散模型;2)作者发掘了扩散模型中来自U-Net架构的归纳偏置(inductive bias),使得它们能高效地处理有空间结构的数据(比如二维图像),避免像之前基于Transformer的方法一样使用激进、有损质量的压缩比例;3)本工作的压缩模型是通用的,它的隐空间能用来训练多种图像生成模型。第一个优点是相对于DDPM。第二个是优点是相对于使用Transformer的VQGAN,我们在上一段已经分析过了。第三个优点是相对于之前那些换一个任务就需要换一个压缩模型的两阶段图像生成模型。

归纳偏置可以简单理解为某个学习算法对一类数据的优势。比如CNN结构适合处理图像数据。

隐扩散模型(LDM)

在DDPM中,一个参数为$\theta$的神经网络$\epsilon_\theta$会根据当前时刻$t$的带噪图片$x_t$预测本时刻的噪声$\epsilon_\theta(x_t, t)$。网络的学习目标是让预测的噪声和真实的噪声$\epsilon$一致。

LDM的原理和DDPM完全一样,只不过训练图片从像素空间上的真实图片$x_0$变成了隐空间上的压缩图片$z_0$,每一轮的带噪图片由$x_t$变成了隐空间上的带噪图片$z_t$。在训练时,相比DDPM,只需要多对$x_0$用一次编码器变成$z_0$即可。

如果你在理解这部分内容时有疑问,请去阅读DDPM的相关文章。LDM的具体结构我们会在第三篇代码阅读文章中讨论。

约束机制

让模型支持带约束图像生成,其实就是想办法把额外的约束信息输入进扩散模型中。显然,最简单的添加约束的方法就是把额外的信息和扩散模型原本的输入$z_t$拼接起来。如果约束是一个值,就把相同的值拼接到$z_t$的每一个像素上;如果约束本身具有空间结构(如语音分割图片),就可以把约束重采样至和$z_t$一样的分辨率,再逐像素拼接。除了直接的拼接外,作者在LDM中还使用了另一种融合约束信息的方法。

DDPM中含有自注意力层。自注意力操作其实基于注意力操作$Attention(Q, K, V)$,它可以解释成一个数据库中存储了许多数据$V$,数据的索引(键)是$K$,现在要用查询$Q$查询数据库里的数据并返回查询结果。注意力操作有几种用法,第一种用法是交叉注意力$CrossAttn(A, B)=Attention(W_Q \cdot A, W_K \cdot B, W_V \cdot B)$,可以理解成数据$A$, $B$做了一次信息融合;第二种用法是自注意力$SelfAttn(A)=Attention(W_Q \cdot A, W_K \cdot A, W_V \cdot A)$,可以理解成数据$A$自己做了一次特征提取。

既然交叉注意力操作可以融合两类信息,何不把DDPM的自注意力层换成交叉注意力层,把$K$, $V$换成来自约束的信息,以实现带约束图像生成呢?如下图所示,通过把用编码器$\tau_\theta$编码过的约束信息输入进扩散模型交叉注意力层的$K$, $V$,LDM实现了带约束图像生成。这里的实现细节我们会在第三篇代码阅读文章中讨论。

根据论文中实验的设计,对于作用于全局的约束,如文本描述,使用交叉注意力较好;对于有空间信息的约束,如语义分割图片,则用拼接的方式较好。

实验

在这一章里,作者按照介绍方法的顺序,依次探究了图像压缩模型、无约束图像生成、带约束图像合成的实验结果。我们主要关心前两部分的实验结果。

感知压缩程度的折衷

论文首先讨论了图像压缩模型在不同的下采样比例$f$下的实验结果,其中$f\in\{1, 2, 4, 8, 16, 32\}$。这些实验分两部分,第一部分是训练速度上的实验,第二部分是采样速度与效果上的实验。

在ImageNet上以不同下采样比例$f$训练一定步数后LDM的采样指标对比结果如下图所示。其中,FID指标越低越好,Inception Score越高越好。结果显示,无论下采样比例$f$是过大还是过小都会降低训练速度。作者分析了$f$较小或较大时训练速度慢的原因:$f$过小时扩散模型把过多的精力放在了本应由压缩模型负责的感知压缩上;$f$过大时图像信息在压缩中损失过多。LDM-$\{4\text{-}16\}$的表现相对好一些。

在实验的第二部分中,作者比较了不同采样比例$f$的LDM在CelebA-HQ(下图左侧)和ImageNet(下图右侧)上的采样速度和采样效果。下图中,横坐标为吞吐量,越靠右表示采样速度越快。同一个模型的不同实验结果表示使用不同DDIM采样步数时的实验结果,每一条线上的结果从右到左分别是DDIM采样步数取$\{10, 20, 50, 100, 200\}$的采样结果(DDIM步数越少,采样速度越快,生成图片质量越低)。对于CelebA-HQ上的实验,若采样步数较多,则还是LDM-$\{4, 8\}$效果较好,只有在采样步数较少时压缩比更高的LDM才有优势。而对于ImageNet上的实验,$f$太小或太大的结果都很差,整体上还是LDM-$\{4, 8\}$的结果较好。

综上,根据实验,作者认为$f$取适中的$4$或$8$比较妥当。下采样比例$f=8$也正是Stable Diffusion采用的配置。

图像生成效果

在这一节中,作者在几个常见的数据集上对比了LDM与其他模型的无约束图像生成效果。作者主要比较了两类指标:表示采样质量的FID和表示数据分布覆盖率的精确率及召回率(Precision-and-Recall)。

在介绍具体结果之前,先对这个不太常见的精确率及召回率指标做一个解释。精确率及召回率常用于分类等有确定答案的任务中,分别表示所有被分类为正的样本中有多少是分对了的、所有真值为正的样本中有多少是被成功分类成正的。而无约束图像生成中的精确率及召回率的解释可以参加论文Improved Precision and Recall Metric for Assessing
Generative Models
。如下图所示,设真实分布为蓝色,生成模型的分布为红色,则红色样本落在蓝色分布的比例为精确率,蓝色样本落在红色分布的比例为召回率。简单来说,精确率能描述采样质量,召回率能描述生成分布与真实分布的覆盖情况。

接下来,我们回头来看论文展示的无约束图像生成对比结果,如下图所示。整体上看,LDM的表现还不错。虽然在FID指标上无法超过GAN或其他扩散模型,但是在精确率和召回率上还是颇具优势。唯一没有被LDM战胜的是LSUN-Bedrooms上的ADM模型,但作者提到,相比ADM,LDM只用了一半的参数,且只需四分之一的训练资源。

带约束图像合成

这一节里,作者展示了LDM的文生图能力。论文中的LDM用了一个从头训练的基于Transformer的文本编码器,与后续使用CLIP的Stable Diffusion差别较大。这一部分的结果没那么重要,大致看一看就好。

本文的文生图模型是一个在LAION-400M数据集上训练的KL约束LDM。它的文本编码器是一个Transformer,编码后的特征会以交叉注意力的形式传入LDM。采样时,LDM使用了Classifier-Free Guidance。

Classifier-Free Guidance可以让输出图片更符合文本约束。这是一种适用于所有扩散模型的采样策略,并非要和LDM绑定,感兴趣可以去阅读相关论文。

LDM与其他模型的文生图效果对比如下图所示。虽然这个版本的LDM并没有显著优于其他模型,但它的参数量是最少的。

LDM在类别约束的图像合成上表现也很不错,超越了当时的其他模型。其结果在此略过。

剩余的带约束图像合成任务都可以看成是图像转图像任务,比如图像超分辨率是低质量图像到高质量图像的转换、语义生成是把语义分割图像转换成一幅合成图像。要添加这些约束,只需要把这些任务的输入图片和LDM原本的输入$z_t$拼接起来即可。比如对于图像超分辨率,可以把输入图片直接与隐空间图片$z_t$拼接,解码后图片会被自然上采样$f$倍;对于语义生成,可以把下采样$f$倍的语义分割图与$z_t$拼接。论文用这些任务上的实验证明了LDM的泛用性。由于这部分实验与LDM的主要知识无关,具体实验结果就不在此详细介绍了。

总结

论文末尾探讨了LDM的两大不足。首先,尽管LDM的计算需求比其他像素空间上的扩散模型要少得多,但受制于扩散模型本身的串行采样,它的采样速度还是比GAN慢上许多。其次,LDM使用了一个自编码器来压缩图像,重建图像带来的精度损失会成为某些需要精准像素值的任务的性能瓶颈。

论文最后再次总结了此方法的贡献。LDM的主要贡献其实只有两点:在不损失效果的情况下用两阶段的图像生成方法大幅提升了训练和采样效率、借助交叉注意力实现了各任务通用的约束机制。这两个贡献总结得非常精准。之后的Stable Diffusion之所以大受欢迎,第一就是因为它采样所需的计算资源不多,大众能使用消费级显卡完成图像生成,第二就是因为它强大的文字转图片生成效果。

我们再从知识学习的角度总结一下LDM。LDM的核心知识是DDPM和VQGAN。如果你能看懂之前这两篇论文,那你一下子就能明白LDM是的核心思想是什么,看论文时只需要精读交叉注意力约束机制那一段即可,其他实验内容在现在看来已经价值不大了。由于近两年有大量基于Stable Diffusion开发的工作,相比论文,阅读源代码的重要性会大很多。我们会在下一篇文章里详细学习Stable Diffusion的官方源码和最常用的Stable Diffusion第三方实现——Diffusers框架。

在2022年的这波AI绘画浪潮中,Stable Diffusion无疑是最受欢迎的图像生成模型。究其原因,第一,Stable Diffusion通过压缩图像尺寸显著提升了扩散模型的运行效率,使得每个用户能在自己的商业级显卡上运行模型;第二,有许多基于Stable Diffusion的应用,比如Stable Diffusion自带的文生图、图像补全,以及ControlNet、LoRA、DreamBooth等插件式应用;第三,得益于前两点,Stable Diffusion已经形成了一个庞大的用户社群,大家互相分享模型,交流心得。

不仅是大众,Stable Diffusion也吸引了大量科研人员,很多本来研究GAN的人纷纷转来研究扩散模型。然而,许多人在学习Stable Diffusion时却犯了难:又是公式扎堆的扩散模型,又是VAE,又是U-Net,这该怎么学起呀?

其实,一上来就读Stable Diffusion是很难读懂的。而如果你把之前的一些更基础的文章读懂,再回头来读Stable Diffusion,就会畅行无阻了。在这篇及之后的几篇文章中,我将从科研的角度对Stable Diffusion做一个全面的解读。在第一篇文章中,我将面向完全没接触过图像生成的读者,从头介绍Stable Diffusion是怎样从早期工作中一步一步诞生的;在第二篇文章中,我将详细解读Stable Diffusion的论文;在最后的第三篇文章中,我将带领大家阅读Stable Diffusion的官方源码,以及一些流行的开源库的Stable Diffusion实现。后续我还会写其他和Stable Diffusion相关的文章,比如ControlNet的介绍。

从自编码器谈起

包括Stable Diffusion在内,很多图像生成模型都可以看成是一种非常简单的模型——自编码器——的改进版。要谈Stable Diffusion是怎么逐渐诞生的,其实就是在谈自编码器是一步一步进化的。我们的学习就从自编码器开始。

尽管PNG、JPG等图像压缩方法已经非常成熟,但我们会想,会不会还有更好的图像压缩算法呢?图像压缩,其实就是找两个映射,一个把图片编码成压缩数据,另一个把压缩数据解码回图片。我们知道,神经网络理论上可以拟合任何映射。那我们干脆用两个神经网络来拟合两种映射,以实现一个图像压缩算法。负责编码的神经网络叫编码器(Encoder),负责解码的神经网络叫做解码器(Decoder)

光定义了神经网络还不够,我们还需要给两个神经网络设置一个学习目标。在运行过程中,神经网络应该满足一个显然的约束:编码再解码后的重建图像应该和原图像尽可能一致,即二者的均方误差应该尽可能小。这样,我们只需要随便找一张图片,通过编码器和解码器得到重建图像,就能训练神经网络了。我们不需要给图片打上标签,整个训练过程是自监督的。所以我们说,整套模型是一个自编码器(Autoencoder,AE)

图像压缩模型AE为什么会和图像生成扯上关系呢?你可以试着把AE的输入图像和编码器遮住,只看解码部分。把一个压缩数据解码成图像,换个角度看,不就是在根据某一数据生成图像嘛。

很可惜,AE并不是一个合格的图像生成模型。我们常说的图像生成,具体是指让程序生成各种各样的图片。为了让程序生成不同的图片,我们一般是让程序根据随机数(或是随机向量)来生成图片。而普通的AE会有过拟合现象,这导致AE的解码器只认得训练集里的图片经编码器解码出来的压缩数据,而不认得随机生成的压缩数据,进而也无法达到图像生成的要求。

所谓过拟合,就是指模型只能处理训练数据,而不能推广到一般的数据上。举一个极端的例子,如下图所示,编码器和解码器直接记忆了整个数据集,把所有图片压缩成了一个数字。也就是模型把编码器当成一个图片到数字的词典,把解码器当成一个数字到图片的词典。这样,不管数据集有多大,所有图片都可以被压缩成一个数字。这样的AE确实压缩能力很强,但它完全没用,因为它过拟合了,处理不了训练集以外的数据。

过拟合现象在普通版AE中是不可避免的。为了利用AE的解码器来生成图片,许多工作都在试图克服AE的过拟合现象。AE的改进思路很多,在这篇文章中,我们仅把AE的改进路线粗略地分成两种:解决过拟合问题以直接用AE做图像生成、用AE压缩图像间接实现图像生成。

第一条路线:VAE 和 DDPM

在第一条改进路线中,许多后续工作都试图用更高级的数学模型来解决AE的过拟合问题。变分自编码器(Variational Autoencoder, VAE) 就是其中的代表。

VAE对AE做了若干改动。第一,VAE让编码器的输出不再是一个确定的数据,而是一个正态分布中的一个随机数据。更具体一点,训练时,编码器会同时输出一个均值和方差。随后,模型会从这个均值和方差表达的正态分布里随机采样一个数据,作为解码器的输入。直观上看,这一改动就是在AE的基础上,让编码器多输出了一个方差,使得原AE编码器的输出发生了一点随机扰动。

这一改动可以缓解过拟合现象。这是为什么呢?我们可以这样想:原来的AE之所以会过拟合,是因为它强行记住了训练集里每一个数据的编码输出。现在,我们在VAE里让编码器不再输出一个固定值,而是随机输出一个在均值附近的值。这样的话,VAE就不能死记硬背了,必须要找出数据中的规律。

VAE的第二项改动是多添加一个学习目标,让编码器的输出和标准正态分布尽可能相似。前面我们谈过,图像生成模型一般会根据一个随机向量来生成图像。最常用的产生随机向量的方法是去标准正态分布里采样。也就是说,在用VAE生成图像时,我们会抛掉编码器,用下图所示的流程来生成图像。如果我们不约束编码器的输出分布,不让它输出一个和标准正态分布很相近的分布的话,解码器就不能很好地根据来自标准正态分布的随机向量生成图像了。

综上,VAE对AE做了两项改进:使编码器输出一个正态分布,且该分布要尽可能和标准正态分布相似。训练时,模型从编码器输出的分布里随机采样一个数据作为解码器的输入;图像采样(图像生成)时,模型从标准正态分布里随机采样一个数据作为解码器的输入。VAE的误差函数由两部分组成:原图像和重建图像的重建误差、编码器输出和标准正态分布之间的误差。VAE要最小化重建误差,最大化编码器输出与标准正态分布的相似度。

分布与分布之间的误差可以用一个叫KL散度的指标表示。所以,在上面那个误差函数公式中,负的相似度应该被替换成KL散度。VAE的这两项改动本质上都是在解决AE的过拟合问题,所以,VAE的改动可以被看成一种正则化方法。我们可以把VAE的正则化方法简称为KL正则化

在机器学习中,正则化方法就是「降低模型过拟合的方法」的简称。

VAE确实能减轻AE的过拟合。然而,由于VAE只是让重建图像和原图像的均方误差(重建误差)尽可能小,而没有对重建图像的质量施加更多的约束,VAE的重建结果和图像生成结果都非常模糊。以下是VAE在CelebA数据集上图像生成结果。

在众多对VAE的改进方法中,一个叫做去噪扩散概率模型(Denoising Diffusion Probabilistic Model, DDPM) 的图像生成模型脱颖而出。DDPM正是当今扩散模型的开山鼻祖。我们来看一下DDPM是怎样基于VAE对图像生成建模的。

VAE之所以效果不好,很可能是因为它的约束太少了。VAE的编码和解码都是用神经网络表示的。神经网络是一个黑盒,我们不好对神经网络的中间步骤施加约束,只好在编码器的输出(某个正态分布)和解码器的输出(重建图像)上施加约束。能不能让VAE的编码和解码过程更可控一点呢?

DDPM的设计灵感来自热力学:一个分布可以通过一系列简单的变化(如添加高斯噪声)逐渐变成另一个分布。恰好,VAE的编码器不正是想让来自训练集的图像(训练集分布)变成标准正态分布吗?既然如此,就不要用一个可学习的神经网络来表示VAE的编码器了,干脆用一些预定义好的加噪声操作来表示解码过程。可以从数学上证明,经过了多次加噪声操作后,最后的图像分布会是一个标准正态分布。

既然编码是加噪声,那解码时就应该去掉噪声。DDPM的解码器也不再是一个不可解释的神经网络,而是一个能预测若干个去噪结果的神经网络。

相比只有两个约束条件的VAE,DDPM的约束条件就多得多了。在DDPM中,第t个去噪操作应该尽可能抵消掉第t个加噪操作。

让我们来更具体地认识一下DDPM的学习目标。所谓添加噪声,就是在一个均值约等于当前图像的正态分布上采样。比如要对图像$\mathbf{x}$添加噪声,我们可以在$\mathcal{N}(0.9\mathbf{x},\mathbf{I})$这个分布里采样一张新图像。新的图像每个像素的均值是原来的0.9倍左右,且新图像会出现很多噪声。我们设$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1})$为第$t$步加噪声的正态分布。经过一些数学推导,我们可以求出这一步操作的逆操作$q(\mathbf{x}_{t-1} | \mathbf{x}_{t})$,这个加噪声逆操作也是一个正态分布。既然如此,我们可以设第$t$步去噪声也为一个正态分布$p(\mathbf{x}_{t-1} | \mathbf{x}_{t})$,让第$t$步去噪声和第$t$步加噪声的逆操作尽可能相似。

总结一下,DDPM对VAE做了如下改动:

  1. 编码器是一系列不可学习(固定)的加噪声操作
  2. 解码器是一系列可学习的去噪声操作
  3. 图像尺寸自始至终不变

相比于VAE,DDPM的编码过程和解码过程的定义更加明确,可以施加的约束更多。因此,如下图所示,它的生成效果会比VAE好很多。同时,DDPM和VAE类似,它在编码时会从分布里采样,而不是只输出一个固定值,不会出现AE的过拟合问题。

DDPM的图像生成结果

DDPM的生成效果确实很好。但是,由于DDPM始终会对同一个尺寸的数据进行操作,图像的尺寸极大地影响了DDPM的运行速度,用DDPM生成高分辨率图像需要耗费大量计算资源。因此,想要用DDPM生成高质量图像,还得经过另一条路线。

第二条路线:VQVAE

在AE的第二条改进路线中,一些工作干脆放弃使用AE做图像生成,转而利用AE的图像压缩能力,把图像生成拆成两步来做:先用AE的编码器把图像压缩成更小的图像,再用另一个图像生成模型生成小图像,并用AE的解码器把小图像重建回真实图像。

为什么会有这么奇怪的图像生成方法呢?这得从另一类图像生成模型讲起。在机器翻译模型Transformer横空出世后的一段时间里,有很多工作都想把Transformer用在图像生成上。但是,原本用来生成文本的Transformer无法直接应用在图像上。在自然语言处理(NLP)中,一个句子可以用若干个单词表示。而每个单词又是用一个整数表示。所以,Transformer生成句子时,实际上是在生成若干个离散的整数,也就是生成一个离散向量。而在图像生成模型中,每个像素的颜色值是一个连续的浮点数。想把Transformer直接用在图像生成上,就得想办法把图像用离散向量表示。我们知道,AE可以把图像编码成一个连续向量。能不能做一些修改,让AE把图像编码成一个离散向量呢?

Vector Quantised-Variational AutoEncoder (VQVAE) 就是一个能把图像编码成离散向量的AE(虽然作者在取名时用了VAE)。我们来简单看一下VQVAE是怎样把图像编码成离散向量的。

假设我们有了一个能编码出离散向量的AE。

由于神经网络不能很好地处理离散数据,我们要引入NLP里的通常做法,加一个把离散向量映射成连续向量的嵌入层。

现在我们再回头讨论怎么让编码器输出一个离散向量。我们可以让AE的解码器保持不变,还是输出一个连续向量,再通过一个「向量离散化」操作,把连续向量变成离散向量。这个操作会把编码器的输出对齐到嵌入层的向量上,其原理类似于把0.99和1.01离散化成1,只不过它是对向量整体考虑,而不是对每一个数单独考虑。向量离散化操作的具体原理我们不在此处细究。

忽略掉实现细节,我们可以认为VQVAE能够把图像压缩成离散向量。更准确地说,VQVAE能把图像等比例压缩成离散的「小图像」。压缩成二维图像而不是一维向量,能够保留原图像的一些空间特性,为之后第二步图像生成铺路。

整理一下,VQVAE是一个能把图像压缩成离散小图像的AE。为了用VQVAE生成图像,需要执行一个两阶段的图像生成流程:

  • 训练时,先训练一个图像压缩模型(VQVAE),再训练一个生成压缩图像的模型(比如Transformer)
  • 生成时,先用第二个模型生成出一个压缩图像,再用第一个模型的解码器把压缩图像复原成真实图像

之所以要执行两阶段的图像生成流程,而不是只用第二个模型生成大图像,有两个原因。第一个原因是前面提到的,Transformer等生成模型只支持生成离散图像,需要用另一个模型把连续的颜色值变成离散值以兼容这些模型。第二个原因是为了减少模型的运算量。以Transformer为例,Transformer的运算次数大致与像素数的平方成正比,拿Transformer生成高分辨率图像的运算开销是不可接受的。而如果用一个AE把图像压缩一下的话,用Transformer就可行了。

VQVAE给后续工作带来了三条启发:第一,可以用AE把图像压缩成离散向量;第二,如果一个图像生成模型生成高分辨率的图像的计算代价太高,可以先用AE把图像压缩,再生成压缩图像。这两条启发对应上一段提到的使用VQVAE的两条动机。

而第三条启发就比较有意思了。在讨论VQVAE的过程中,我们完全没有考虑过拟合的事。这是因为经过了向量离散化操作后,解码器的输入已经不再是编码器的输出,而是嵌入层里的向量了。这种做法杜绝了AE的死记硬背,缓解了过拟合现象。这样,我们可以换一个角度看待VQVAE:编码器还是AE的编码器,编码器的输出是连续向量,后续的向量离散化操作和嵌入层全部都是解码器的一部分。从这个角度看,VQVAE其实提出了一个由向量离散化和嵌入层组成的正则化模块。这个模块和VAE的KL散度约束一样,都解决了AE的过拟合问题。我们把VQVAE的正则化方法叫做VQ正则化

VQVAE论文提出的图像生成方法效果一般。和普通的AE一样,VQVAE在训练时只用了重建误差来约束图像质量,重建图像的细节依然很模糊。且VQVAE配套的第二阶段图像生成模型不是较为强力的Transformer,而是一个基于CNN的图像生成模型。

后续的VQGAN论文对VQVAE进行了改进。对于一阶段的图像压缩模型,VQGAN在VQVAE的基础上引入了生成对抗网络(GAN)中一些监督误差,提高了图像压缩模型的重建质量;对于两阶段的图像生成模型,该方法使用了Transformer。凭借这些改动,VQGAN方法能够生成高质量的高清图片。并且,通过把额外的约束条件(如语义分割图像、文字)输入进Transformer,VQGAN方法能够实现带约束的图像生成。以下是VQGAN方法根据语义分割图像生成的高清图片。

图像生成模型可以是无约束或带约束的。无约束图像生成模型只需要输入一个随机向量,训练数据不需要任何标注,可以进行无监督训练。带约束图像生成模型会在无约束图像生成模型的基础上多加一些输入,并给每个训练图像打上描述约束的标签,执行监督训练。比如要训练文生图模型,就要给每个训练图片带上文字描述。

路线的交汇点——Stable Diffusion

看完上面这两条AE的改进路线,相信你已经能够猜出Stable Diffusion的核心思想了。让我们看看Stable Diffusion是怎么从这两条路径中汲取灵感的。

在发布了VQGAN后,德国的CompVis实验室开始探索起VQGAN的改进方法。VQGAN能把图像边长压缩16倍,而VQGAN配套的Transformer只能一次生成$16 \times 16$的图片。也就是说,整套方法一次只能生成$256 \times 256$的图片。为了生成分辨率更高的图片,VQGAN方法需要借助滑动窗口。能不能让模型一次性生成分辨率更高的图片呢?制约VQGAN方法生成分辨率的主要因素是Transformer。如果能把Transformer换成一个效率更高,能生成更高分辨率的图像的模型,不就能生成比$256\times256$更大的图片了吗?CompVis实验室开始把目光着眼于DDPM上。

于是,在发布VQGAN的一年后,CompVis实验室又发布了名为High-Resolution Image Synthesis with Latent Diffusion Models的论文,提出了一种叫做隐扩散模型(latent diffusion model, LDM) 的图像生成模型。通过与AI公司Stability AI合作,借助他们庞大的算力资源训练LDM,CompVis实验室发布了商业名为Stable Diffusion的开源文生图AI绘画模型。

LDM其实就是在VQGAN方法的基础上,把图像生成模型从Transformer换成了DDPM。或者从另一个角度说,为了让DDPM生成高分辨率图像,LDM利用了VQVAE的第二条启发:先用AE把图像压缩,再用DDPM生成压缩图像。LDM的AE一般是把图像边长压缩8倍,DDPM生成$64 \times 64$的压缩图像,整套LDM能生成$512 \times 512$的图像。

和Transformer不同,DDPM处理的图像是用连续向量表示的。因此,在LDM中使用VQGAN做图像压缩时,不一定需要向量离散化操作,只需要在AE的基础上加一点轻微的正则化就行。作者在实现LDM时讨论了两类正则化,一类是VAE的KL正则化,一类是VQ正则化(对应VQVAE的第三条启发),两种正则化都能取得不错的效果。

LDM依然可以实现带约束的图像生成。用DDPM替换掉Transformer后,额外的约束会输入进DDPM中。作者在论文中讨论了几种把约束输入进DDPM的方式。

在搞懂了早期工作后,理解Stable Diffusion的核心思想就是这么简单。让我们把Stable Diffusion的发展过程及主要结构总结一下。Stable Diffusion由两类AE的变种发展而来,一类是有强大生成能力却需要耗费大量运算资源的DDPM,一类是能够以较高保真度压缩图像的VQVAE。Stable Diffusion是一个两阶段的图像生成模型,它先用一个使用KL正则化或VQ正则化的VQGAN来实现图像压缩,再用DDPM生成压缩图像。可以把额外的约束(如文字)输入进DDPM以实现带约束图像生成。

相关论文

本文仅仅对Stable Diffusion的早期工作做了一个简单的梳理。要把Stable Diffusion吃透,还需要多读一些早期论文。我来把早期论文按重要性分个类。

图像生成必读文章

Neural Discrete Representation Learning (VQVAE): https://arxiv.org/abs/1711.00937

Taming Transformers for High-Resolution Image Synthesis (VQGAN): https://arxiv.org/abs/2012.09841

Denoising Diffusion Probabilistic Models (DDPM): https://arxiv.org/abs/2006.11239

图像生成选读文章

Auto-Encoding Variational Bayes (VAE): https://arxiv.org/abs/1312.6114 提出VAE的文章。数学公式较多,只需要了解VAE的大致结构就好,不需要详细阅读论文。

Pixel Recurrent Neural Networks (PixelCNN): https://arxiv.org/abs/1601.06759 提出了一种拟合离散分布的图像生成模型,自回归图像生成模型的代表。这是VQVAE使用的第二阶段图像生成模型。有兴趣可以了解一下。

Deep Unsupervised Learning using Nonequilibrium Thermodynamics: https://arxiv.org/abs/1503.03585 DDPM的前作,首个提出扩散模型思想的文章。其核心原理和DDPM几乎完全一致,但是模型结构和优化目标不够先进,生成效果没有改进后的DDPM好。数学公式较多,不必细读,可以在学习DDPM时对比着阅读。

Denoising Diffusion Implicit Models (DDIM): https://arxiv.org/abs/2010.02502 一种加速DDPM采样的方法,广泛运用在包含Stable Diffusion在内的扩散模型中。推荐阅读。

Classifier-Free Diffusion Guidance: https://arxiv.org/abs/2207.12598 一种让扩散模型的输出更加贴近约束的方法,广泛运用在包含Stable Diffusion在内的扩散模型中,用于生成更符合文字描述的图片。推荐阅读。

Generative Adversarial Networks (GAN): https://arxiv.org/abs/1406.2661 以及 A Style-Based Generator Architecture for Generative Adversarial Networks (StyleGAN): https://arxiv.org/abs/1812.04948 可以了解一下GAN是怎么确保图像生成质量的,并认识CelebAHQ和FFHQ这两个常用的人脸数据集。

其他必读文章

Deep Residual Learning for Image Recognition (ResNet): https://arxiv.org/abs/1512.03385 深度学习的经典文章。其中提出的残差连接被用到了DDPM中。

Attention Is All You Need (Transformer): https://arxiv.org/abs/1706.03762 深度学习的经典文章。其中提出的自注意力模块被用到了DDPM中。

其他选读文章

Learning Transferable Visual Models From Natural Language Supervision (CLIP): https://arxiv.org/abs/2103.00020 提出了对齐文本和图像的方法。绝大多数文生图模型的核心。

U-Net: Convolutional Networks for Biomedical Image Segmentation (U-Net): https://arxiv.org/abs/1505.04597 一种被广泛运用的神经网络架构。DDPM的神经网络的主架构。U-Net的结构很简单,可以不用去读论文,直接看代码。

我的解读文章

我对这上面的很多论文都做过解读。如果你在阅读论文的时候碰到了困难,欢迎阅读我的解读。

轻松理解 VQ-VAE:首个提出 codebook 机制的生成模型

VQGAN 论文与源码解读:前Diffusion时代的高清图像生成模型

扩散模型(Diffusion Model)详解:直观理解、数学原理、PyTorch 实现

抛开数学,轻松学懂 VAE(附 PyTorch 实现)

冷门的自回归生成模型 ~ 详解 PixelCNN 大家族

DDIM 简明讲解与 PyTorch 实现:加速扩散模型采样的通用方法

用18支画笔作画的AI ~ StyleGAN特点浅析

ResNet 论文概览与精读

Attention Is All You Need (Transformer) 论文精读

相比于多数图像生成模型,去噪扩散概率模型(Denoising Diffusion Probabilistic Model, DDPM)的采样速度非常慢。这是因为DDPM在采样时通常要做1000次去噪操作。但如果你玩过基于扩散模型的图像生成应用的话,你会发现,大多数应用只需要20次去噪即可生成图像。这是为什么呢?原来,这些应用都使用了一种更快速的采样方法——去噪扩散隐式模型(Denoising Diffusion Implicit Model, DDIM)。

基于DDPM,DDIM论文主要提出了两项改进。第一,对于一个已经训练好的DDPM,只需要对采样公式做简单的修改,模型就能在去噪时「跳步骤」,在一步去噪迭代中直接预测若干次去噪后的结果。比如说,假设模型从时刻$T=100$开始去噪,新的模型可以在每步去噪迭代中预测10次去噪操作后的结果,也就是逐步预测时刻$t=90, 80, …, 0$的结果。这样,DDPM的采样速度就被加速了10倍。第二,DDIM论文推广了DDPM的数学模型,从更高的视角定义了DDPM的前向过程(加噪过程)和反向过程(去噪过程)。在这个新数学模型下,我们可以自定义模型的噪声强度,让同一个训练好的DDPM有不同的采样效果。

在这篇文章中,我将言简意赅地介绍DDIM的建模方法,并给出我的DDIM PyTorch实现与实验结果。本文不会深究DDIM的数学推导,对这部分感兴趣的读者可以去阅读我在文末给出的参考资料。

回顾 DDPM

DDIM是建立在DDPM之上的一篇工作。在正式认识DDIM之前,我们先回顾一下DDPM中的一些关键内容,再从中引出DDIM的改进思想。

DDPM是一个特殊的VAE。它的编码器是$T$步固定的加噪操作,解码器是$T$步可学习的去噪操作。模型的学习目标是让每一步去噪操作尽可能抵消掉对应的加噪操作。

DDPM的加噪和去噪操作其实都是在某个正态分布中采样。因此,我们可以用概率$q, p$分别表示加噪和去噪的分布。比如 $q(\mathbf{x}_t|\mathbf{x}_{t-1})$ 就是由第 $t-1$ 时刻的图像到第 $t$ 时刻的图像的加噪声分布, $p(\mathbf{x}_{t-1}|\mathbf{x}_{t})$ 就是由第 $t$ 时刻的图像到第 $t-1$ 时刻的图像的去噪声分布。这样,我们可以说网络的学习目标是让 $p(\mathbf{x}_{t-1} | \mathbf{x}_{t})$ 尽可能与 $q(\mathbf{x}_t | \mathbf{x}_{t-1})$ 和互逆。

但是,「互逆」并不是一个严格的数学表述。更具体地,我们应该让分布$p(\mathbf{x}_{t-1} | \mathbf{x}_{t})$和分布$q(\mathbf{x}_{t-1} | \mathbf{x}_{t})$尽可能相似。$q(\mathbf{x}_{t-1} | \mathbf{x}_{t})$和$p(\mathbf{x}_{t-1} | \mathbf{x}_{t})$的关系就和VAE中原图像与重建图像的关系一样。

$q(\mathbf{x}_{t-1} | \mathbf{x}_{t})$是不好求得的,但在给定了输入数据$\mathbf{x}_{0}$时,$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0})$是可以用贝叶斯公式求出来的:

我们不必关心具体的求解方法,只需要知道从等式右边的三项$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)、q(\mathbf{x}_{t-1} | \mathbf{x}_0)、q(\mathbf{x}_{t} | \mathbf{x}_0)$可以推导出等式左边的那一项。在DDPM中,$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1})$是一个定义好的式子,且$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}) = q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$。根据$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1})$,可以推出$q(\mathbf{x}_{t} | \mathbf{x}_0)$。知道了$q(\mathbf{x}_{t} | \mathbf{x}_0)$,$q(\mathbf{x}_{t-1} | \mathbf{x}_0)$也就知道了(把公式里的$t$换成$t-1$就行了)。这样,在DDPM中,等式右边的式子全部已知,等式左边的$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0})$可以直接求出来。

上述推理过程可以简单地表示为:知道$q(\mathbf{x}_{t} | \mathbf{x}_0)$和$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$,就知道了神经网络的学习目标$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$。这几个公式在DDPM中的具体形式如下:

其中,只有参数$\beta_t$是可调的。$\bar{\alpha}_t$是根据$\beta_t$算出的变量,其计算方法为:$\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i$。

由于学习目标$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$里只有一个未知变量$\mathbf{x}_0$,DDPM把学习目标简化成了只让神经网络根据$\mathbf{x}_{t}$拟合公式里的$\mathbf{x}_{0}$(更具体一点,是拟合从$\mathbf{x}_{0}$到$\mathbf{x}_{t}$的噪声)。也就是说,在训练时,$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$的公式不会被用到,只有$\mathbf{x}_{t}$和$\mathbf{x}_{0}$两个量之间的公式$q(\mathbf{x}_{t} | \mathbf{x}_0)$会被用到。只有在采样时,$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$的公式才会被用到。训练目标的推理过程可以总结为:

理解「DDPM的训练目标只有$\mathbf{x}_{0}$」对于理解DDIM非常关键。如果你在回顾DDPM时出现了问题,请再次阅读DDPM的相关介绍文章。

加速 DDPM

我们再次审视一下DDPM的推理过程:首先有$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}) = q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$。根据$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1})$,可以推出$q(\mathbf{x}_{t} | \mathbf{x}_0)$。知道$q(\mathbf{x}_{t} | \mathbf{x}_0)$和$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$,由贝叶斯公式,就知道了学习目标$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$。

根据这一推理过程,DDIM论文的作者想到,假如我们把贝叶斯公式中的$t$替换成$t_2$, $t-1$替换成$t_1$,其中$t_2$是比$t_1$大的任意某一时刻,那么我们不就可以从$t_2$到$t_1$跳步骤去噪了吗?比如令$t_2 = t_1 + 10$,我们就可以求出去除10次噪声的公式,去噪的过程就快了10倍。

修改之后,$q(\mathbf{x}_{t_1} | \mathbf{x}_0)$和$q(\mathbf{x}_{t_2} | \mathbf{x}_0)$依然很好求,只要把$t_1$, $t_2$代入普通的$q(\mathbf{x}_{t} | \mathbf{x}_0)$公式里就行。

但是,$q(\mathbf{x}_{t_2} | \mathbf{x}_{t_1}, \mathbf{x}_0)$怎么求呢?原来的$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t};\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})$来自于DDPM的定义,我们能直接把公式拿来用。能不能把$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的公式稍微修改一下,让它兼容$q(\mathbf{x}_{t_2} | \mathbf{x}_{t_1}, \mathbf{x}_0)$呢?

修改$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的思路如下:假如我们能把公式中的$\beta_t$换成一个由$t$和$t-1$决定的变量,我们就能把$t$换成$t_2$,$t-1$换成$t_1$,也就得到了$q(\mathbf{x}_{t_2} | \mathbf{x}_{t_1}, \mathbf{x}_0)$。

那怎么修改$\beta_t$的形式呢?很简单。我们知道$\beta_t$决定了$\bar{\alpha}_t$:$\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i$。那么我们用$\bar{\alpha}_t$除以$\bar{\alpha}_{t-1}$,不就得到了$1-\beta_t$了吗?也就是说:

我们把这个用$\bar{\alpha}_t$和$\bar{\alpha}_{t-1}$表示的$\beta_t$套入$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的公式里,再把$t$换成$t_2$,$t-1$换成$t_1$,就得到了$q(\mathbf{x}_{t_2} | \mathbf{x}_{t_1}, \mathbf{x}_0)$。有了这一项,贝叶斯公式等式右边那三项我们就全部已知,可以求出$q(\mathbf{x}_{t_1} | \mathbf{x}_{t_2}, \mathbf{x}_0)$,也就是可以一次性得到多个时刻后的去噪结果。

在这个过程中,我们只是把DDPM公式里的$\bar{\alpha}_t$换成$\bar{\alpha}_{t2}$,$\bar{\alpha}_{t-1}$换成$\bar{\alpha}_{t1}$,公式推导过程完全不变。网络的训练目标$\mathbf{x}_{0}$也没有发生改变,只是采样时的公式需要修改。这意味着我们可以先照着原DDPM的方法训练,再用这种更快速的方式采样。

我们之前只讨论了$t_1$到$t_2$为固定值的情况。实际上,我们不一定要间隔固定的时刻去噪一次,完全可以用原时刻序列的任意一个子序列来去噪。比如去噪100次的DDPM的去噪时刻序列为[99, 98, ..., 0],我们可以随便取一个长度为10的子序列:[99, 98, 77, 66, 55, 44, 33, 22, 1, 0],按这些时刻来去噪也能让采样速度加速10倍。但实践中没人会这样做,一般都是等间距地取时刻。

这样看来,在采样时,只有部分时刻才会被用到。那我们能不能顺着这个思路,干脆训练一个有效时刻更短(总时刻$T$不变)的DDPM,以加速训练呢?又或者保持有效训练时刻数不变,增大总时刻$T$呢?DDIM论文的作者提出了这些想法,认为这可以作为后续工作的研究方向。

从 DDPM 到 DDIM

除了加速DDPM外,DDIM论文还提出了一种更普遍的DDPM。在这种新的数学模型下,我们可以任意调节采样时的方差大小。让我们来看一下这个数学模型的推导过程。

DDPM的学习目标$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$由$q(\mathbf{x}_{t} | \mathbf{x}_0)$和$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$决定。具体来说,在求解正态分布$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$时,我们会将它的均值$\tilde{\mu}_t$和方差$\tilde{\beta}_t$设为未知量,并将条件$q(\mathbf{x}_{t} | \mathbf{x}_0)$、$q(\mathbf{x}_{t-1} | \mathbf{x}_0)$、$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$代入,求解出确定的$\tilde{\mu}_t$和$\tilde{\beta}_t$。

在上文我们分析过,DDPM训练时只需要拟合$\mathbf{x}_0$,只需要用到$\mathbf{x}_0$和$\mathbf{x}_t$的关系$q(\mathbf{x}_{t} | \mathbf{x}_0)$。在不修改训练过程的前提下,我们能不能把限制$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$去掉(即$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$可以是任意一个正态分布,而不是我们提前定义好的一个正态分布),得到一个更普遍的DDPM呢?

这当然是可以的。根据基础的解方程知识,我们知道,去掉一个方程后,会多出一个自由变量。取消了$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的限制后,均值$\tilde{\mu}_t$和方差$\tilde{\beta}_t$就不能同时确定下来了。我们可以令方差$\tilde{\beta}_t$为自由变量,并让$\tilde{\mu}_t$用含$\tilde{\beta}_t$的式子表示出来。这样,我们就得到了一个方差可变的更一般的DDPM。

让我们来看一下这个新模型的具体公式。原来的DDPM的加噪声逆操作的分布为:

新的分布公式为:

新公式是旧公式的一个推广版本。如果我们把DDPM的方差$(1-\bar{\alpha}_{t-1})/(1 - \bar{\alpha}_{t}) \cdot \beta_t$代入新公式里的$\tilde{\beta}_t$,就能把新公式还原成DDPM的公式。和DDPM的公式一样,我们也可以把$\mathbf{x}_{0}$拆成$\mathbf{x}_{t}$和噪声$\epsilon$表示的式子。

现在采样时方差可以随意取了,我们来讨论一种特殊的方差取值——$\tilde{\beta}_t=0$。也就是说,扩散模型的反向过程变成了一个没有噪声的确定性过程。给定随机噪声$\mathbf{x}_{T}$,我们只能得到唯一的采样结果$\mathbf{x}_{0}$。这种结果确定的概率模型被称为隐式概率模型(implicit probabilistic model)。所以,论文作者把方差为0的这种扩散模型称为DDIM(Denoising Diffusion Implicit Model)。

为了方便地选取方差值,作者将方差改写为

其中,$\eta\in[0, 1]$。通过选择不同的$\eta$,我们实际上是在DDPM和DDIM之间插值。$\eta$控制了插值的比例。$\eta=0$,模型是DDIM;$\eta=1$,模型是DDPM。

除此之外,DDPM论文曾在采样时使用了另一种方差取值:$\tilde{\beta}_t=\beta_t$,即去噪方差等于加噪方差。实验显示这个方差的采样结果还不错。我们可以把这个取值也用到DDIM论文提出的方法里,只不过这个方差值不能直接套进上面的公式。在代码实现部分我会介绍该怎么在DDIM方法中使用这个方差。

注意,在这一节的推导过程中,我们依然没有修改DDPM的训练目标。我们可以把这种的新的采样方法用在预训练的DDPM上。当然,我们可以在使用新的采样方法的同时也使用上一节的加速采样方法。

实验

到这里为止,我们已经学完了DDIM论文的两大内容:加速采样、更换采样方差。加速采样的意义很好理解,它能大幅减少采样时间。可更换采样方差有什么意义呢?我们看完论文中的实验结果就知道了。

论文展示了新采样方法在不同方差、不同采样步数下的FID指标(越小越好)。其中,$\hat{\sigma}$表示使用DDPM中的$\tilde{\beta}_t=\beta_t$方差取值。实验结果非常有趣。在使用采样加速(步数比总时刻1000要小)时,$\eta=0$的DDIM的表现最好,而$\hat{\sigma}$的情况则非常差。而当$\eta$增大,模型越来越靠近DDPM时,用$\hat{\sigma}$的结果会越来越好。而在DDPM中,用$\hat{\sigma}$的结果是最好的。

从这个实验结果中,我们可以得到一条很简单的实践指南:如果使用了采样加速,一定要用效果最好的DDIM;而使用原DDPM的话,可以维持原论文提出的$\tilde{\beta}_t=\beta_t$方差取值。

总结

DDIM论文提出了DDPM的两个拓展方向:加速采样、变更采样方差。通过同时使用这两个方法,我们能够在不重新训练DDPM、尽可能不降低生成质量的前提下,让扩散模型的采样速度大幅提升(一般可以快50倍)。让我们再从头理一理提出DDIM方法的思考过程。

为了能直接使用预训练的DDPM,我们希望在改进DDPM时不更改DDPM的训练过程。而经过简化后,DDPM的训练目标只有拟合$\mathbf{x}_{0}$,训练时只会用到前向过程公式$q(\mathbf{x}_{t} | \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t}; \sqrt{\bar{\alpha}_t}\mathbf{x}_{0}, (1-\bar{\alpha}_t)\mathbf{I})$。所以,我们的改进应该建立在公式$q(\mathbf{x}_{t} | \mathbf{x}_0)$完全不变的前提下。

通过对DDPM反向过程公式的简单修改,也就是把$t$改成$t_2$,$t-1$改成$t_1$,我们可以把去噪一步的公式改成去噪多步的公式,以大幅加速DDPM。可是,这样改完之后,采样的质量会有明显的下降。

我们可以猜测,减少了采样迭代次数后,采样质量之所以下降,是因为每次估计的去噪均值更加不准确。而每次去噪迭代中的噪声(由方差项决定的那一项)放大了均值的不准确性。我们能不能干脆让去噪时的方差为0呢?为了让去噪时的方差可以自由变动,我们可以去掉DDPM的约束条件。由于贝叶斯公式里的$q(\mathbf{x}_{t} | \mathbf{x}_0)$不能修改,我们只能去掉$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的限制。去掉限制后,方差就成了自由变量。我们让去噪方差为0,让采样过程没有噪声。这样,就得到了本文提出的DDIM模型。实验证明,在采样迭代次数减少后,使用DDIM的生成结果是最优的。

在本文中,我较为严格地区分了DDPM和DDIM的叫法:DDPM指DDPM论文中提出的有1000个扩散时刻的模型,它的采样方差只有两种取值($\tilde{\beta}_t=(1-\bar{\alpha}_{t-1})/(1 - \bar{\alpha}_{t}) \cdot \beta_t$, $\tilde{\beta}_t=\beta_t$)。DDIM指DDIM论文中提出的$\eta=0$的推广版DDPM模型。DDPM和DDIM都可以使用采样加速。但是,从习惯上我们会把没有优化加速的DDPM称为”DDPM”,把$\eta$可以任取,采样迭代次数可以任取的采样方法统称为”DDIM”。一些开源库中会有叫DDIMSampler的类,调节$\eta$的参数大概会命名为eta,调节迭代次数的参数大概会命名为ddim_num_steps。一般我们令eta=0ddim_num_steps=20即可。

DDIM的代码实现没有太多的学习价值,只要在DDPM代码的基础上把新数学公式翻译成代码即可。其中唯一值得注意的就是如何在DDIM中使用DDPM的方差$\tilde{\beta}_t=\beta_t$。对此感兴趣的话可以阅读我接下来的代码实现介绍。

在这篇解读中,我略过了DDIM论文中的大部分数学推导细节。对DDIM数学模型的推导过程感兴趣的话,可以阅读我在参考文献中推荐的文章,或者看一看原论文。

DDIM PyTorch 实现

在这个项目中,我们将对一个在CelebAHQ上预训练的DDPM执行DDIM采样,尝试复现论文中的那个FID表格,以观察不同etaddim_steps对于采样结果的影响。

代码仓库:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/ddim

DDPM 基础项目

DDIM只是DDPM的一种采样改进策略。为了复现DDIM的结果,我们需要一个DDPM基础项目。由于DDPM并不是本文的重点,在这一小节里我将简要介绍我的DDPM实现代码的框架。

我们的实验需要使用CelebAHQ数据集,请在 https://www.kaggle.com/datasets/badasstechie/celebahq-resized-256x256 下载该数据集并解压到项目的data/celebA/celeba_hq_256目录下。另外,我在Hugging Face上分享了一个在64x64 CelebAHQ上训练的DDPM模型:https://huggingface.co/SingleZombie/dldemos/tree/main/ckpt/ddim ,请把它放到项目的dldemos/ddim目录下。

先运行dldemos/ddim/dataset.py下载MNIST,再直接运行dldemos/ddim/main.py,代码会自动完成MNIST上的训练,并执行步数1000的两种采样和步数20的三种采样,同时将结果保存在目录work_dirs中。以下是我得到的MNIST DDPM采样结果(存储在work_dirs/diffusion_ddpm_sigma_hat.jpg中)。

为了查看64x64 CelebAHQ上的采样结果,可以在dldemos/ddim/main.py的main函数里把config_id改成2,再注释掉训练函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 0 for MNIST. See configs.py
config_id = 2
cfg = configs[config_id]
n_steps = 1000
device = 'cuda'
model_path = cfg['model_path']
img_shape = cfg['img_shape']
to_bgr = False if cfg['dataset_type'] == 'MNIST' else True

net = UNet(n_steps, img_shape, cfg['channels'], cfg['pe_dim'],
cfg.get('with_attn', False), cfg.get('norm_type', 'ln'))
ddpm = DDPM(device, n_steps)

# train(ddpm,
# net,
# cfg['dataset_type'],
# resolution=(img_shape[1], img_shape[2]),
# batch_size=cfg['batch_size'],
# n_epochs=cfg['n_epochs'],
# device=device,
# ckpt_path=model_path)

以下是我得到的CelebAHQ DDPM采样结果(存储在work_dirs/diffusion_ddpm_sigma_hat.jpg中)。

项目目录下的configs.py存储了训练配置,dataset.py定义了DataLoadernetwork.py定义了U-Net的结构,ddpm.pyddim.py分别定义了普通的DDPM前向过程和采样以及DDIM采样,dist_train.py提供了并行训练脚本,dist_sample.py提供了并行采样脚本,main.py提供了单卡运行的所有任务脚本。

在这个项目中,我们的主要的目标是基于其他文件,编写ddim.py。我们先来看一下原来的DDPM类是怎么实现的,再仿照它的接口写一个DDIM类。

实现 DDIM 采样

在我的设计中,DDPM类不是一个神经网络(torch.nn.Module),它仅仅维护了扩散模型的alpha等变量,并描述了前向过程和反向过程。

DDPM类中,我们可以在初始化函数里定义好要用到的self.betas, self.alphas, self.alpha_bars变量。如果在工程项目中,我们可以预定义好更多的常量以节约采样时间。但在学习时,我们可以少写一点代码,让项目更清晰一点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class DDPM():

def __init__(self,
device,
n_steps: int,
min_beta: float = 0.0001,
max_beta: float = 0.02):
betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
alphas = 1 - betas
alpha_bars = torch.empty_like(alphas)
product = 1
for i, alpha in enumerate(alphas):
product *= alpha
alpha_bars[i] = product
self.betas = betas
self.n_steps = n_steps
self.alphas = alphas
self.alpha_bars = alpha_bars

前向过程就是把正态分布的公式$q(\mathbf{x}_{t} | \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t}; \sqrt{\bar{\alpha}_t}\mathbf{x}_{0}, (1-\bar{\alpha}_t)\mathbf{I})$翻译一下。

1
2
3
4
5
6
def sample_forward(self, x, t, eps=None):
alpha_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)
if eps is None:
eps = torch.randn_like(x)
res = eps * torch.sqrt(1 - alpha_bar) + torch.sqrt(alpha_bar) * x
return res

在反向过程中,我们从self.n_steps1枚举时刻t(代码中时刻和数组下标有1的偏差),按照公式算出每一步的去噪均值和方差,执行去噪。算法流程如下:

参数simple_var=True表示令方差$\sigma_t^2=\beta_t$,而不是$(1-\bar{\alpha}_{t-1})/(1 - \bar{\alpha}_{t}) \cdot \beta_t$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def sample_backward(self, img_or_shape, net, device, simple_var=True):
if isinstance(img_or_shape, torch.Tensor):
x = img_or_shape
else:
x = torch.randn(img_or_shape).to(device)
net = net.to(device)
for t in tqdm(range(self.n_steps - 1, -1, -1), "DDPM sampling"):
x = self.sample_backward_step(x, t, net, simple_var)

return x

def sample_backward_step(self, x_t, t, net, simple_var=True):

n = x_t.shape[0]
t_tensor = torch.tensor([t] * n,
dtype=torch.long).to(x_t.device).unsqueeze(1)
eps = net(x_t, t_tensor)

if t == 0:
noise = 0
else:
if simple_var:
var = self.betas[t]
else:
var = (1 - self.alpha_bars[t - 1]) / (
1 - self.alpha_bars[t]) * self.betas[t]
noise = torch.randn_like(x_t)
noise *= torch.sqrt(var)

mean = (x_t -
(1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *
eps) / torch.sqrt(self.alphas[t])
x_t = mean + noise

return x_t

接下来,我们来实现DDIM类。DDIMDDPM的推广,我们可以直接用DDIM类继承DDPM类。它们共享初始化函数与前向过程函数。

1
2
3
4
5
6
7
8
class DDIM(DDPM):

def __init__(self,
device,
n_steps: int,
min_beta: float = 0.0001,
max_beta: float = 0.02):
super().__init__(device, n_steps, min_beta, max_beta)

我们要修改的只有反向过程的实现函数。整个函数的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def sample_backward(self,
img_or_shape,
net,
device,
simple_var=True,
ddim_step=20,
eta=1):
if simple_var:
eta = 1
ts = torch.linspace(self.n_steps, 0,
(ddim_step + 1)).to(device).to(torch.long)
if isinstance(img_or_shape, torch.Tensor):
x = img_or_shape
else:
x = torch.randn(img_or_shape).to(device)
batch_size = x.shape[0]
net = net.to(device)
for i in tqdm(range(1, ddim_step + 1),
f'DDIM sampling with eta {eta} simple_var {simple_var}'):
cur_t = ts[i - 1] - 1
prev_t = ts[i] - 1

ab_cur = self.alpha_bars[cur_t]
ab_prev = self.alpha_bars[prev_t] if prev_t >= 0 else 1

t_tensor = torch.tensor([cur_t] * batch_size,
dtype=torch.long).to(device).unsqueeze(1)
eps = net(x, t_tensor)
var = eta * (1 - ab_prev) / (1 - ab_cur) * (1 - ab_cur / ab_prev)
noise = torch.randn_like(x)

first_term = (ab_prev / ab_cur)**0.5 * x
second_term = ((1 - ab_prev - var)**0.5 -
(ab_prev * (1 - ab_cur) / ab_cur)**0.5) * eps
if simple_var:
third_term = (1 - ab_cur / ab_prev)**0.5 * noise
else:
third_term = var**0.5 * noise
x = first_term + second_term + third_term

return x

我们来把整个函数过一遍。先看一下函数的参数。相比DDPM,DDIM的采样会多出两个参数:ddim_step, eta。如正文所述,ddim_step表示执行几轮去噪迭代,eta表示DDPM和DDIM的插值系数。

1
2
3
4
5
6
7
def sample_backward(self,
img_or_shape,
net,
device,
simple_var=True,
ddim_step=20,
eta=1):

在开始迭代前,要做一些预处理。根据论文的描述,如果使用了DDPM的那种简单方差,一定要令eta=1。所以,一开始我们根据simple_vareta做一个处理。之后,我们要准备好迭代时用到的时刻。整个迭代过程中,我们会用到从self.n_steps0等间距的ddim_step+1个时刻(self.n_steps是初始时刻,不在去噪迭代中)。比如总时刻self.n_steps=100ddim_step=10ts数组里的内容就是[100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0]

1
2
3
4
5
6
7
8
9
10
if simple_var:
eta = 1
ts = torch.linspace(self.n_steps, 0,
(ddim_step + 1)).to(device).to(torch.long)
if isinstance(img_or_shape, torch.Tensor):
x = img_or_shape
else:
x = torch.randn(img_or_shape).to(device)
batch_size = x.shape[0]
net = net.to(device)

做好预处理后,进入去噪循环。在for循环中,我们从1ddim_step遍历ts的下标,从时刻数组ts里取出较大的时刻cur_t(正文中的$t_2$)和较小的时刻prev_t(正文中的$t_1$)。由于self.alpha_bars存储的是t=1, t=2, ..., t=n_steps时的变量,时刻和数组下标之间有一个1的偏移,我们要把ts里的时刻减去1得到时刻在self.alpha_bars里的下标,再取出对应的变量ab_cur, ab_prev。注意,在当前时刻为0时,self.alpha_bars是没有定义的。但由于self.alpha_bars表示连乘,我们可以特别地令当前时刻为0(prev_t=-1)时的alpha_bar=1

1
2
3
4
5
6
7
for i in tqdm(range(1, ddim_step + 1),
f'DDIM sampling with eta {eta} simple_var {simple_var}'):
cur_t = ts[i - 1] - 1
prev_t = ts[i] - 1

ab_cur = self.alpha_bars[cur_t]
ab_prev = self.alpha_bars[prev_t] if prev_t >= 0 else 1

准备好时刻后,我们使用和DDPM一样的方法,用U-Net估计生成x_t时的噪声eps,并准备好DDPM采样算法里的噪声noise(公式里的$\mathbf{z}$)。
与DDPM不同,在计算方差var时(公式里的$\sigma_t^2$),我们要给方差乘一个权重eta

1
2
3
4
5
t_tensor = torch.tensor([cur_t] * batch_size,
dtype=torch.long).to(device).unsqueeze(1)
eps = net(x, t_tensor)
var = eta * (1 - ab_prev) / (1 - ab_cur) * (1 - ab_cur / ab_prev)
noise = torch.randn_like(x)

接下来,我们要把之前算好的所有变量用起来,套入DDIM的去噪均值计算公式中。

也就是(设$\sigma_t^2 = \tilde{\beta}_t$, $\mathbf{z}$为来自标准正态分布的噪声):

由于我们只有噪声$\epsilon$,要把$\mathbf{x}_0=(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\epsilon)/\sqrt{\bar{\alpha}_t}$代入,得到不含$\mathbf{x}_0$的公式:

我在代码里把公式的三项分别命名为first_term, second_term, third_term,以便查看。

特别地,当使用DDPM的$\hat{\sigma_t}$方差取值(令$\sigma_t^2=\beta_t=\hat{\sigma_t}^2$)时,不能把这个方差套入公式中,不然$\sqrt{1-{\bar{\alpha}}_{t}-\sigma_t^2}$的根号里的数会小于0。DDIM论文提出的做法是,只修改后面和噪声$\mathbf{z}$有关的方差项,前面这个根号里的方差项保持$\sigma_t^2=(1-\bar{\alpha}_{t-1})/(1 - \bar{\alpha}_{t}) \cdot \beta_t$ ($\eta=1$)的取值。

当然,上面这些公式全都是在描述$t$到$t-1$。当描述$t_2$到$t_1$时,只需要把$\beta_t$换成$1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}$,再把所有$t$换成$t_2$,$t-1$换成$t_1$即可。

把上面的公式和处理逻辑翻译成代码,就是这样:

1
2
3
4
5
6
7
8
first_term = (ab_prev / ab_cur)**0.5 * x
second_term = ((1 - ab_prev - var)**0.5 -
(ab_prev * (1 - ab_cur) / ab_cur)**0.5) * eps
if simple_var:
third_term = (1 - ab_cur / ab_prev)**0.5 * noise
else:
third_term = var**0.5 * noise
x = first_term + second_term + third_term

这样,下一刻的x就算完了。反复执行循环即可得到最终的结果。

实验

写完了DDIM采样后,我们可以编写一个随机生成图片的函数。由于DDPMDDIM的接口非常相似,我们可以用同一套代码实现DDPM或DDIM的采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def sample_imgs(ddpm,
net,
output_path,
img_shape,
n_sample=64,
device='cuda',
simple_var=True,
to_bgr=False,
**kwargs):
if img_shape[1] >= 256:
max_batch_size = 16
elif img_shape[1] >= 128:
max_batch_size = 64
else:
max_batch_size = 256

net = net.to(device)
net = net.eval()

index = 0
with torch.no_grad():
while n_sample > 0:
if n_sample >= max_batch_size:
batch_size = max_batch_size
else:
batch_size = n_sample
n_sample -= batch_size
shape = (batch_size, *img_shape)
imgs = ddpm.sample_backward(shape,
net,
device=device,
simple_var=simple_var,
**kwargs).detach().cpu()
imgs = (imgs + 1) / 2 * 255
imgs = imgs.clamp(0, 255).to(torch.uint8)

img_list = einops.rearrange(imgs, 'n c h w -> n h w c').numpy()
output_dir = os.path.splitext(output_path)[0]
os.makedirs(output_dir, exist_ok=True)
for i, img in enumerate(img_list):
if to_bgr:
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
cv2.imwrite(f'{output_dir}/{i+index}.jpg', img)

# First iteration
if index == 0:
imgs = einops.rearrange(imgs,
'(b1 b2) c h w -> (b1 h) (b2 w) c',
b1=int(batch_size**0.5))
imgs = imgs.numpy()
if to_bgr:
imgs = cv2.cvtColor(imgs, cv2.COLOR_RGB2BGR)
cv2.imwrite(output_path, imgs)

index += batch_size

为了生成大量图片以计算FID,在这个函数中我加入了很多和batch有关的处理。剔除这些处理代码以及图像存储后处理代码,和采样有关的核心代码为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def sample_imgs(ddpm,
net,
output_path,
img_shape,
n_sample=64,
device='cuda',
simple_var=True,
to_bgr=False,
**kwargs):

net = net.to(device)
net = net.eval()

with torch.no_grad():
shape = (n_sample, *img_shape)
imgs = ddpm.sample_backward(shape,
net,
device=device,
simple_var=simple_var,
**kwargs).detach().cpu()

如果是用DDPM采样,把参数表里的那些参数填完就行了;如果是DDIM采样,则需要在kwargs里指定ddim_stepeta

使用这个函数,我们可以进行不同ddim_step和不同eta下的64x64 CelebAHQ采样实验,以尝试复现DDIM论文的实验结果。

我们先准备好变量。

1
2
3
4
5
net = UNet(n_steps, img_shape, cfg['channels'], cfg['pe_dim'],
cfg.get('with_attn', False), cfg.get('norm_type', 'ln'))
ddpm = DDPM(device, n_steps)
ddim = DDIM(device, n_steps)
net.load_state_dict(torch.load(model_path))

第一组实验是总时刻保持1000,使用$\hat{\sigma}_t$(标准DDPM)和$\eta=0$(标准DDIM)的实验。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
sample_imgs(ddpm,
net,
'work_dirs/diffusion_ddpm_sigma_hat.jpg',
img_shape,
device=device,
to_bgr=to_bgr)
sample_imgs(ddim,
net,
'work_dirs/diffusion_ddpm_eta_0.jpg',
img_shape,
device=device,
to_bgr=to_bgr,
ddim_step=1000,
simple_var=False,
eta=0)

把参数n_samples改成30000,就可以生成30000张图像,以和30000张图像的CelebAHQ之间算FID指标。由于总时刻1000的采样速度非常非常慢,建议使用dist_sample.py并行采样。

算FID指标时,可以使用torch fidelity库。使用pip即可安装此库。

1
pip install torch-fidelity

之后就可以使用命令fidelity来算指标了。假设我们把降采样过的CelebAHQ存储在data/celebA/celeba_hq_64,把我们生成的30000张图片存在work_dirs/diffusion_ddpm_sigma_hat,就可以用下面的命令算FID指标。

1
fidelity --gpu 0 --fid --input1 work_dirs/diffusion_ddpm_sigma_hat --input2 data/celebA/celeba_hq_64

整体来看,我的模型比论文差一点,总的FID会高一点。各个配置下的对比结果也稍有出入。在第一组实验中,使用$\hat{\sigma}_t$时,我的FID是13.68;使用$\eta=0$时,我的FID是13.09。而论文中用$\hat{\sigma}_t$时的FID比$\eta=0$时更低。

我们还可以做第二组实验,测试ddim_step=20(我设置的默认步数)时使用$\eta=0$, $\eta=1$, $\hat{\sigma}_t$的生成效果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
sample_imgs(ddim,
net,
'work_dirs/diffusion_ddim_sigma_hat.jpg',
img_shape,
device=device,
simple_var=True,
to_bgr=to_bgr)
sample_imgs(ddim,
net,
'work_dirs/diffusion_ddim_eta_1.jpg',
img_shape,
device=device,
simple_var=False,
eta=1,
to_bgr=to_bgr)
sample_imgs(ddim,
net,
'work_dirs/diffusion_ddim_eta_0.jpg',
img_shape,
device=device,
simple_var=False,
eta=0,
to_bgr=to_bgr)

我的FID结果是:

1
2
3
eta=0: 17.80
eta=1: 24.00
sigma hat: 213.16

这里得到的实验结果和论文一致。减少采样迭代次数后,生成质量略有降低。同采样步数下,eta=0最优。使用sigma hat的结果会有非常多的噪声,差得完全不能看。

综合上面两个实验来看,不管什么情况下,使用eta=0,得到的结果都不会太差。

从生成速度上来看,在64x64 CelebAHQ上生成256张图片,ddim_step=20时只要3秒不到,而ddim_step=1000时要200秒。基本上是步数减少到几分之一就提速几倍。可见,DDIM加速采样对于扩散模型来说是必要的。

参考文献及学习提示

如果对DDIM公式推导及其他数学知识感兴趣,欢迎阅读苏剑林的文章:
https://spaces.ac.cn/archives/9181。

DDIM的论文为Denoising diffusion implicit models(https://arxiv.org/abs/2010.02502)。

我在本文使用的公式符号都基于DDPM论文,与上面两篇文章使用的符号不一样。比如DDIM论文里的$\alpha$在本文中是用$\bar{\alpha}$表示。

DDIM论文在介绍新均值公式时很不友好地在3.1节直接不加解释地给出了公式的形式,并在附录B中以先给结论再证明这种和逻辑思维完全反过来的方法介绍了公式的由来。建议去阅读苏剑林的文章,看看是怎么按正常的思考方式正向推导出DDIM公式。

除了在3.1节直接甩给你一个公式外,DDIM论文后面的地方都很好读懂。DDIM后面还介绍了一些比较有趣的内容,比如4.3节介绍了扩散模型和常微分方程的关系,它可以帮助我们理解为什么DDPM会设置$T=1000$这么长的加噪步数。5.3节中作者介绍了如何用DDIM在两幅图像间插值。

要回顾DDPM的知识,欢迎阅读我之前的文章:DDPM详解。

在过去的大半年里,以Stable Diffusion为代表的AI绘画是世界上最为火热的AI方向之一。或许大家会有疑问,Stable Diffusion里的这个”Diffusion”是什么意思?其实,扩散模型(Diffusion Model)正是Stable Diffusion中负责生成图像的模型。想要理解Stable Diffusion的原理,就一定绕不过扩散模型的学习。

Stable Diffusion以「毕加索笔下的《最后的晚餐》」为题的绘画结果

在这篇文章里,我会由浅入深地对最基础的去噪扩散概率模型(Denoising Diffusion Probabilistic Models, DDPM)进行讲解。我会先介绍扩散模型生成图像的基本原理,再用简单的数学语言对扩散模型建模,最后给出扩散模型的一份PyTorch实现。本文不会堆砌过于复杂的数学公式,哪怕你没有相关的数学背景,也能够轻松理解扩散模型的原理。

扩散模型与图像生成

在认识扩散模型之前,我们先退一步,看看一般的神经网络模型是怎么生成图像的。显然,为了生成丰富的图像,一个图像生成程序要根据随机数来生成图像。通常,这种随机数是一个满足标准正态分布的随机向量。这样,每次要生成新图像时,只需要从标准正态分布里随机生成一个向量并输入给程序就行了。

而在AI绘画程序中,负责生成图像的是一个神经网络模型。神经网络需要从数据中学习。对于图像生成任务,神经网络的训练数据一般是一些同类型的图片。比如一个绘制人脸的神经网络会用人脸照片来训练。也就是说,神经网络会学习如何把一个向量映射成一张图片,并确保这个图片和训练集的图片是一类图片。

可是,相比其他AI任务,图像生成任务对神经网络来说更加困难一点——图像生成任务缺乏有效的指导。在其他AI任务中,训练集本身会给出一个「标准答案」,指导AI的输出向标准答案靠拢。比如对于图像分类任务,训练集会给出每一幅图像的类别;对于人脸验证任务,训练集会给出两张人脸照片是不是同一个人;对于目标检测任务,训练集会给出目标的具体位置。然而,图像生成任务是没有标准答案的。图像生成数据集里只有一些同类型图片,却没有指导AI如何画得更好的信息。

为了解决这一问题,人们专门设计了一些用于生成图像的神经网络架构。这些架构中比较出名的有生成对抗模型(GAN)和变分自编码器(VAE)。

GAN的想法是,既然不知道一幅图片好不好,就干脆再训练一个神经网络,用于辨别某图片是不是和训练集里的图片长得一样。生成图像的神经网络叫做生成器,鉴定图像的神经网络叫做判别器。两个网络互相对抗,共同进步。

VAE则使用了逆向思维:学习向量生成图像很困难,那就再同时学习怎么用图像生成向量。这样,把某图像变成向量,再用该向量生成图像,就应该得到一幅和原图像一模一样的图像。每一个向量的绘画结果有了一个标准答案,可以用一般的优化方法来指导网络的训练了。VAE中,把图像变成向量的网络叫做编码器,把向量转换回图像的网络叫做解码器。其中,解码器就是负责生成图像的模型。

一直以来,GAN的生成效果较好,但训练起来比VAE麻烦很多。有没有和GAN一样强大,训练起来又方便的生成网络架构呢?扩散模型正是满足这些要求的生成网络架构。

扩散模型是一种特殊的VAE,其灵感来自于热力学:一个分布可以通过不断地添加噪声变成另一个分布。放到图像生成任务里,就是来自训练集的图像可以通过不断添加噪声变成符合标准正态分布的图像。从这个角度出发,我们可以对VAE做以下修改:1)不再训练一个可学习的编码器,而是把编码过程固定成不断添加噪声的过程;2)不再把图像压缩成更短的向量,而是自始至终都对一个等大的图像做操作。解码器依然是一个可学习的神经网络,它的目的也同样是实现编码的逆操作。不过,既然现在编码过程变成了加噪,那么解码器就应该负责去噪。而对于神经网络来说,去噪任务学习起来会更加有效。因此,扩散模型既不会涉及GAN中复杂的对抗训练,又比VAE更强大一点。

具体来说,扩散模型由正向过程反向过程这两部分组成,对应VAE中的编码和解码。在正向过程中,输入$\mathbf{x}_0$会不断混入高斯噪声。经过$T$次加噪声操作后,图像$\mathbf{x}_T$会变成一幅符合标准正态分布的纯噪声图像。而在反向过程中,我们希望训练出一个神经网络,该网络能够学会$T$个去噪声操作,把$\mathbf{x}_T$还原回$\mathbf{x}_0$。网络的学习目标是让$T$个去噪声操作正好能抵消掉对应的加噪声操作。训练完毕后,只需要从标准正态分布里随机采样出一个噪声,再利用反向过程里的神经网络把该噪声恢复成一幅图像,就能够生成一幅图片了。

高斯噪声,就是一幅各处颜色值都满足高斯分布(正态分布)的噪声图像。

总结一下,图像生成网络会学习如何把一个向量映射成一幅图像。设计网络架构时,最重要的是设计学习目标,让网络生成的图像和给定数据集里的图像相似。VAE的做法是使用两个网络,一个学习把图像编码成向量,另一个学习把向量解码回图像,它们的目标是让复原图像和原图像尽可能相似。学习完毕后,解码器就是图像生成网络。扩散模型是一种更具体的VAE。它把编码过程固定为加噪声,并让解码器学习怎么样消除之前添加的每一步噪声。

扩散模型的具体算法

上一节中,我们只是大概了解扩散模型的整体思想。这一节,我们来引入一些数学表示,来看一看扩散模型的训练算法和采样算法具体是什么。为了便于理解,这一节会出现一些不是那么严谨的数学描述。更加详细的一些数学推导会放到下一节里介绍。

前向过程

在前向过程中,来自训练集的图像$\mathbf{x}_0$会被添加$T$次噪声,使得$x_T$为符合标准正态分布。准确来说,「加噪声」并不是给上一时刻的图像加上噪声值,而是从一个均值与上一时刻图像相关的正态分布里采样出一幅新图像。如下面的公式所示,$\mathbf{x}_{t - 1}$是上一时刻的图像,$\mathbf{x}_{t}$是这一时刻生成的图像,该图像是从一个均值与$\mathbf{x}_{t - 1}$有关的正态分布里采样出来的。

多数文章会说前向过程是一个马尔可夫过程。其实,马尔可夫过程的意思就是当前时刻的状态只由上一时刻的状态决定,而不由更早的状态决定。上面的公式表明,计算$\mathbf{x}_t$,只需要用到$\mathbf{x}_{t - 1}$,而不需要用到$\mathbf{x}_{t - 2}, \mathbf{x}_{t - 3}…$,这符合马尔可夫过程的定义。

绝大多数扩散模型会把这个正态分布设置成这个形式:

这个正态分布公式乍看起来很奇怪:$\sqrt{1 - \beta_t}$是哪里冒出来的?为什么会有这种奇怪的系数?别急,我们先来看另一个问题:假如给定$\mathbf{x}_{0}$,也就是从训练集里采样出一幅图片,该怎么计算任意一个时刻$t$的噪声图像$\mathbf{x}_{t}$呢?

我们不妨按照公式,从$\mathbf{x}_{t}$开始倒推。$\mathbf{x}_{t}$其实可以通过一个标准正态分布的样本$\epsilon_{t-1}$算出来:

再往前推几步:

由正态分布的性质可知,均值相同的正态分布「加」在一起后,方差也会加到一起。也就是$\mathcal{N}(0, \sigma_1^2 I)$与$\mathcal{N}(0, \sigma_2^2 I)$合起来会得到$\mathcal{N}(0, (\sigma_1^2+\sigma_2^2) I)$。根据这一性质,上面的公式可以化简为:

再往前推一步的话,结果是:

我们已经能够猜出规律来了,可以一直把公式推到$\mathbf{x}_{0}$。令$\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i$,则:

有了这个公式,我们就可以讨论加噪声公式为什么是$\mathbf{x}_t \sim \mathcal{N}(\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})$了。这个公式里的$\beta_t$是一个小于1的常数。在DDPM论文中,$\beta_t$从$\beta_1=10^{-4}$到$\beta_T=0.02$线性增长。这样,$\beta_t$变大,$\alpha_t$也越小,$\bar{\alpha}_t$趋于0的速度越来越快。最后,$\bar{\alpha}_T$几乎为0,代入$\mathbf{x}_T = \sqrt{\bar{\alpha}_T}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_T}\epsilon$, $\mathbf{x}_T$就满足标准正态分布了,符合我们对扩散模型的要求。上述推断可以简单描述为:加噪声公式能够从慢到快地改变原图像,让图像最终均值为0,方差为$\mathbf{I}$。

大家不妨尝试一下,设加噪声公式中均值和方差前的系数分别为$a, b$,按照上述过程计算最终分布的方差。只有$a^2 + b^2 = 1$才能保证最后$\mathbf{x}_T$的方差系数为1。

反向过程

在正向过程中,我们人为设置了$T$步加噪声过程。而在反向过程中,我们希望能够倒过来取消每一步加噪声操作,让一幅纯噪声图像变回数据集里的图像。这样,利用这个去噪声过程,我们就可以把任意一个从标准正态分布里采样出来的噪声图像变成一幅和训练数据长得差不多的图像,从而起到图像生成的目的。

现在问题来了:去噪声操作的数学形式是怎么样的?怎么让神经网络来学习它呢?数学原理表明,当$\beta_t$足够小时,每一步加噪声的逆操作也满足正态分布。

其中,当前时刻加噪声逆操作的均值$\tilde{\mu}_t$和方差$\tilde{\beta}_t$由当前的时刻$t$、当前的图像$\mathbf{x}_{t}$决定。因此,为了描述所有去噪声操作,神经网络应该输入$t$、$\mathbf{x}_{t}$,拟合当前的均值$\tilde{\mu}_t$和方差$\tilde{\beta}_t$。

不要被上文的「去噪声」、「加噪声逆操作」绕晕了哦。由于加噪声是固定的,加噪声的逆操作也是固定的。理想情况下,我们希望去噪操作就等于加噪声逆操作。然而,加噪声的逆操作不太可能从理论上求得,我们只能用一个神经网络去拟合它。去噪声操作和加噪声逆操作的关系,就是神经网络的预测值和真值的关系。

现在问题来了:加噪声逆操作的均值和方差是什么?

直接计算所有数据的加噪声逆操作的分布是不太现实的。但是,如果给定了某个训练集输入$\mathbf{x}_0$,多了一个限定条件后,该分布是可以用贝叶斯公式计算的(其中$q$表示概率分布):

等式左边的$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1};\tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})$表示加噪声操作的逆操作,它的均值和方差都是待求的。右边的$q(\mathbf{x}_{t} | \mathbf{x}_{t-1}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t};\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})$是加噪声的分布。而由于$\mathbf{x}_0$已知,$q(\mathbf{x}_{t-1} | \mathbf{x}_0)$和$q(\mathbf{x}_{t} | \mathbf{x}_0)$两项可以根据前面的公式$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon_t$得来:

这样,等式右边的式子全部已知。我们可以把公式套入,算出给定$\mathbf{x}_0$时的去噪声分布。经计算化简,分布的均值为:

其中,$\epsilon_t$是用公式算$\mathbf{x}_t$时从标准正态分布采样出的样本,它来自公式

分布的方差为:

注意,$\beta_t$是加噪声的方差,是一个常量。那么,加噪声逆操作的方差$\tilde{\beta}_t$也是一个常量,不与输入$\mathbf{x}_0$相关。这下就省事了,训练去噪网络时,神经网络只用拟合$T$个均值就行,不用再拟合方差了。

知道了均值和方差的真值,训练神经网络只差最后的问题了:该怎么设置训练的损失函数?加噪声逆操作和去噪声操作都是正态分布,网络的训练目标应该是让每对正态分布更加接近。那怎么用损失函数描述两个分布尽可能接近呢?最直观的想法,肯定是让两个正态分布的均值尽可能接近,方差尽可能接近。根据上文的分析,方差是常量,只用让均值尽可能接近就可以了。

那怎么用数学公式表达让均值更接近呢?再观察一下目标均值的公式:

神经网络拟合均值时,$\mathbf{x}_{t}$是已知的(别忘了,图像是一步一步倒着去噪的)。式子里唯一不确定的只有$\epsilon_t$。既然如此,神经网络干脆也别预测均值了,直接预测一个噪声$\epsilon_\theta(\mathbf{x}_{t}, t)$(其中$\theta$为可学习参数),让它和生成$\mathbf{x}_{t}$的噪声$\epsilon_t$的均方误差最小就行了。对于一轮训练,最终的误差函数可以写成

这样,我们就认识了反向过程的所有内容。总结一下,反向过程中,神经网络应该让$T$个去噪声操作拟合对应的$T$个加噪声逆操作。每步加噪声逆操作符合正态分布,且在给定某个输入时,该正态分布的均值和方差是可以用解析式表达出来的。因此,神经网络的学习目标就是让其输出的去噪声分布和理论计算的加噪声逆操作分布一致。经过数学计算上的一些化简,问题被转换成了拟合生成$\mathbf{x}_{t}$时用到的随机噪声$\epsilon_t$。

训练算法与采样算法

理解了前向过程和反向过程后,训练神经网络的算法和采样图片(生成图片)的算法就呼之欲出了。

以下是DDPM论文中的训练算法:

让我们来逐行理解一下这个算法。第二行是指从训练集里取一个数据$\mathbf{x}_{0}$。第三行是指随机从$1, …, T$里取一个时刻用来训练。我们虽然要求神经网络拟合$T$个正态分布,但实际训练时,不用一轮预测$T$个结果,只需要随机预测$T$个时刻中某一个时刻的结果就行。第四行指随机生成一个噪声$\epsilon$,该噪声是用于执行前向过程生成$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon$的。之后,我们把$\mathbf{x}_t$和$t$传给神经网络$\epsilon_\theta(\mathbf{x}_{t}, t)$,让神经网络预测随机噪声。训练的损失函数是预测噪声和实际噪声之间的均方误差,对此损失函数采用梯度下降即可优化网络。

DDPM并没有规定神经网络的结构。根据任务的难易程度,我们可以自己定义简单或复杂的网络结构。这里只需要把$\epsilon_\theta(\mathbf{x}_{t}, t)$当成一个普通的映射即可。

训练好了网络后,我们可以执行反向过程,对任意一幅噪声图像去噪,以实现图像生成。这个算法如下:

第一行的$\mathbf{x}_{T}$就是从标准正态分布里随机采样的输入噪声。要生成不同的图像,只需要更换这个噪声。后面的过程就是扩散模型的反向过程。令时刻从$T$到$1$,计算这一时刻去噪声操作的均值和方差,并采样出$\mathbf{x}_{t-1}$。均值是用之前提到的公式计算的:

而方差$\sigma_t^2$的公式有两种选择,两个公式都能产生差不多的结果。实验表明,当$\mathbf{x}_{0}$是特定的某个数据时,用上一节推导出来的方差最好。

而当$\mathbf{x}_{0} \sim \mathcal{N}(0, \mathbf{I})$时,只需要令方差和加噪声时的方差一样即可。

循环执行去噪声操作。最后生成的$\mathbf{x}_{0}$就是生成出来的图像。

特别地,最后一步去噪声是不用加方差项的。为什么呢,观察公式$\sigma_t^2=\frac{1-\bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}} \cdot \beta_t$。当$t=1$时,分子会出现$\bar{\alpha}_{t-1}=\bar{\alpha}_0$这一项。$\bar{\alpha}_t$是一个连乘,理论上$t$是从$1$开始的,在$t=0$时没有定义。但我们可以特别地令连乘的第0项$\bar{\alpha}_0=1$。这样,$t=1$时方差项的分子$1-\bar{\alpha}_{t-1}$为$0$,不用算这一项了。

当然,这一解释从数学上来说是不严谨的。据论文说,这部分的解释可以参见朗之万动力学。

数学推导的补充 (选读)

理解了训练算法和采样算法,我们就算是搞懂了扩散模型,可以去编写代码了。不过,上文的描述省略了一些数学推导的细节。如果对扩散模型更深的原理感兴趣,可以阅读一下本节。

加噪声逆操作均值和方差的推导

上一节,我们根据下面几个式子

一步就给出了$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})$的均值和方差。

现在我们来看一下推导均值和方差的思路。

首先,把其他几个式子带入贝叶斯公式的等式右边。

由于多个正态分布的乘积还是一个正态分布,我们知道$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$也可以用一个正态分布公式$\mathcal{N}(\mathbf{x}_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})$表达,它最后一定能写成这种形式:

问题就变成了怎么把开始那个很长的式子化简,算出$\tilde{\mu}_t$和$\tilde{\beta}_t$。

方差$\tilde{\beta}_t$可以从指数函数的系数得来,比较好求。系数为

所以,方差为:

接下来只要关注指数函数的指数部分。指数部分一定是一个关于的$\mathbf{x}_{t-1}$的二次函数,只要化简成$(\mathbf{x}_{t-1}-C)^2$的形式,再除以一下$-2$倍方差,就可以得到均值了。

指数部分为:

$\mathbf{x}_{t-1}$只在前两项里有。把和$\mathbf{x}_{t-1}$有关的项计算化简,可以计算出均值:

回想一下,在去噪声中,神经网络的输入是$\mathbf{x}_{t}$和$t$。也就是说,上式中$\mathbf{x}_{t}$已知,只有$\mathbf{x}_{0}$一个未知量。要算均值,还需要算出$\mathbf{x}_{0}$。$\mathbf{x}_{0}$和$\mathbf{x}_{t}$之间是有一定联系的。$\mathbf{x}_{t}$是$\mathbf{x}_{0}$在正向过程中第$t$步加噪声的结果。而根据正向过程的公式倒推:

把这个$\mathbf{x}_{0}$带入均值公式,均值最后会化简成我们熟悉的形式。

优化目标

上一节,我们只是简单地说神经网络的优化目标是让加噪声和去噪声的均值接近。而让均值接近,就是让生成$\mathbf{x}_t$的噪声$\epsilon_t$更接近。实际上,这个优化目标是经过简化得来的。扩散模型最早的优化目标是有一定的数学意义的。

扩散模型,全称为扩散概率模型(Diffusion Probabilistic Model)。最简单的一类扩散模型,是去噪扩散概率模型(Denoising Diffusion Probabilistic Model),也就是常说的DDPM。DDPM的框架主要是由两篇论文建立起来的。第一篇论文是首次提出扩散模型思想的Deep Unsupervised Learning using Nonequilibrium Thermodynamics。在此基础上,Denoising Diffusion Probabilistic Models对最早的扩散模型做出了一定的简化,让图像生成效果大幅提升,促成了扩散模型的广泛使用。我们上一节看到的公式,全部是简化后的结果。

扩散概率模型的名字之所以有「概率」二字,是因为这个模型是在描述一个系统的概率。准确来说,扩散模型是在描述经反向过程生成出某一项数据的概率。也就是说,扩散模型$p_{\theta}(\mathbf{x}_0)$是一个有着可训练参数$\theta$的模型,它描述了反向过程生成出数据$\mathbf{x}_0$的概率。$p_{\theta}(\mathbf{x}_0)$满足$p_{\theta}(\mathbf{x}_0)=\int p_{\theta}(\mathbf{x}_{0:T})d\mathbf{x}_{1:T}$,其中$p_{\theta}(\mathbf{x}_{0:T})$就是我们熟悉的反向过程,只不过它是以概率计算的形式表达:

我们上一节里见到的优化目标,是让去噪声操作$p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t})$和加噪声操作的逆操作$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$尽可能相似。然而,这个描述并不确切。扩散模型原本的目标,是最大化$p_{\theta}(\mathbf{x}_0)$这个概率,其中$\mathbf{x}_0$是来自训练集的数据。换个角度说,给定一个训练集的数据$\mathbf{x}_0$,经过前向过程和反向过程,扩散模型要让复原出$\mathbf{x}_0$的概率尽可能大。这也是我们在本文开头认识VAE时见到的优化目标。

最大化$p_{\theta}(\mathbf{x}_0)$,一般会写成最小化其负对数值,即最小化$-log p_{\theta}(\mathbf{x}_0)$。使用和VAE类似的变分推理,可以把优化目标转换成优化一个叫做变分下界(variational lower bound, VLB)的量。它最终可以写成:

这里的$D_{KL}(P||Q)$表示分布P和Q之间的KL散度。KL散度是衡量两个分布相似度的指标。如果$P, Q$都是正态分布,则它们的KL散度可以由一个简单的公式给出。关于KL散度的知识可以参见我之前的文章:从零理解熵、交叉熵、KL散度。

其中,第一项$D_{KL}(q(\mathbf{x}_T|\mathbf{x}_0) || p_\theta(\mathbf{x}_T))$和可学习参数$\theta$无关(因为可学习参数只描述了每一步去噪声操作,也就是只描述了$p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t})$),可以不去管它。那么这个优化目标就由两部分组成:

  1. 最小化$D_{KL}(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) || p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t}))$表示的是最大化每一个去噪声操作和加噪声逆操作的相似度。
  2. 最小化$- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})$就是已知$\mathbf{x}_{1}$时,让最后复原原图$\mathbf{x}_{0}$概率更高。

我们分别看这两部分是怎么计算的。

对于第一部分,我们先回顾一下正态分布之间的KL散度公式。设一维正态分布$P, Q$的公式如下:

而对于$D_{KL}(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) || p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t}))$,根据前文的分析,我们知道,待求方差$\Sigma_\theta(\mathbf{x}_{t}, t)$可以直接由计算得到。

两个正态分布方差的比值是常量。所以,在计算KL散度时,不用管方差那一项了,只需要管均值那一项。

由根据之前的均值公式

这一部分的优化目标可以化简成

DDPM论文指出,如果把前面的系数全部丢掉的话,模型的效果更好。最终,我们就能得到一个非常简单的优化目标:

这就是我们上一节见到的优化目标。

当然,还没完,别忘了优化目标里还有$- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})$这一项。它的形式为:

只管后面有$\theta$的那一项(注意,$\alpha_1=\bar{\alpha}_1=1-\beta_1$):

这和那些KL散度项$t=1$时的形式相同,我们可以用相同的方式简化优化目标,只保留$|| \epsilon_1-\epsilon_\theta(\mathbf{x}_{1}, 1)||^2$。这样,损失函数的形式全都是$||\epsilon_t-\epsilon_{\theta}(\mathbf{x}_{t}, t)||^2$了。

DDPM论文里写$- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})$这一项可以直接满足简化后的公式$t=1$时的情况,而没有去掉系数的过程。我在网上没找到文章解释这一点,只好按自己的理解来推导这个误差项了。不论如何,推导的过程不是那么重要,重要的是最后的简化形式。

总结

图像生成任务就是把随机生成的向量(噪声)映射成和训练图像类似的图像。为此,扩散模型把这个过程看成是对纯噪声图像的去噪过程。通过学习把图像逐步变成纯噪声的逆操作,扩散模型可以把任何一个纯噪声图像变成有意义的图像,也就是完成图像生成。

对于不同程度的读者,应该对本文有不同的认识。

对于只想了解扩散模型大概原理的读者,只需要阅读第一节,并大概了解:

  • 图像生成任务的通常做法
  • 图像生成任务需要监督
  • VAE通过把图像编码再解码来训练一个解码器
  • 扩散模型是一类特殊的VAE,它的编码固定为加噪声,解码固定为去噪声

对于想认真学习扩散模型的读者,只需读懂第二节的主要内容:

  • 扩散模型的优化目标:让反向过程尽可能成为正向过程的逆操作
  • 正向过程的公式
  • 反向过程的做法(采样算法)
  • 加噪声逆操作的均值和方差在给定$\mathbf{x}_{0}$时可以求出来的,加噪声逆操作的均值就是去噪声的学习目标
  • 简化后的损失函数与训练算法

对有学有余力对数学感兴趣的读者,可以看一看第三节的内容:

  • 加噪声逆操作均值和方差的推导
  • 扩散模型最早的优化目标与DDPM论文是如何简化优化目标的

我个人认为,由于扩散模型的优化目标已经被大幅度简化,除非你的研究目标是改进扩散模型本身,否则没必要花过多的时间钻研数学原理。在学习时,建议快点看懂扩散模型的整体思想,搞懂最核心的训练算法和采样算法,跑通代码。之后就可以去看较新的论文了。

在附录中,我给出了一份DDPM的简单实现。欢迎大家参考,并自己动手复现一遍DDPM。

参考资料与学习建议

网上绝大多数的中英文教程都是照搬 https://lilianweng.github.io/posts/2021-07-11-diffusion-models/ 这篇文章的。这篇文章像教科书一样严谨,适合有一定数学基础的人阅读,但不适合给初学者学习。建议在弄懂扩散模型的大概原理后再来阅读这篇文章补充细节。

多数介绍扩散模型的文章对没学过相关数学知识的人来说很不友好,我在阅读此类文章时碰到了大量的问题:为什么前向公式里有个$\sqrt{1-\beta}$?为什么突然冒出一个快速算$\mathbf{x}_{t}$的公式?为什么反向过程里来了个贝叶斯公式?优化目标是什么?$-log p_{\theta}(\mathbf{x}_0)$是什么?为什么优化目标里一大堆项,每一项的意义又是什么?为什么最后莫名其妙算一个$\epsilon$?为什么采样时$t=0$就不用加方差项了?好不容易,我才把这些问题慢慢搞懂,并在本文做出了解释。希望我的解答能够帮助到同样有这些困惑的读者。想逐步学习扩散模型,可以先看懂我这篇文章的大概讲解,再去其他文章里学懂一些细节。无论是教,还是学,最重要的都是搞懂整体思路,知道动机,最后再去强调细节。

再强烈推荐一位作者写的DDPM系列介绍:https://kexue.fm/archives/9119 。这位作者是全网为数不多的能令我敬佩的作者。早知道有这些文章,我也没必要自己写一遍了。

这里还有篇文章给出了扩散模型中数学公式的详细推导,并补充了变分推理的背景介绍,适合从头学起:https://arxiv.org/abs/2208.11970

想深入学习DDPM,可以看一看最重要的两篇论文:Deep Unsupervised Learning using Nonequilibrium ThermodynamicsDenoising Diffusion Probabilistic Models。当然,后者更重要一些,里面的一些实验结果仍有阅读价值。

我在代码复现时参考了这篇文章。相对于网上的其他开源DDPM实现,这份代码比较简短易懂,更适合学习。不过,这份代码有一点问题。它的神经网络不够强大,采样结果会有一点问题。

附录:代码复现

在这个项目中,我们要用PyTorch实现一个基于U-Net的DDPM,并在MNIST数据集(经典的手写数字数据集)上训练它。模型几分钟就能训练完,我们可以方便地做各种各样的实验。

后续讲解只会给出代码片段,完整的代码请参见 https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/ddpm 。git clone 仓库并安装后,可以直接运行目录里的main.py训练模型并采样。

获取数据集

PyTorch的torchvision提供了获取了MNIST的接口,我们只需要用下面的函数就可以生成MNIST的Dataset实例。参数中,root为数据集的下载路径,download为是否自动下载数据集。令download=True的话,第一次调用该函数时会自动下载数据集,而第二次之后就不用下载了,函数会读取存储在root里的数据。

1
mnist = torchvision.datasets.MNIST(root='data/mnist', download=True)

我们可以用下面的代码来下载MNIST并输出该数据集的一些信息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torchvision
from torchvision.transforms import ToTensor
def download_dataset():
mnist = torchvision.datasets.MNIST(root='data/mnist', download=True)
print('length of MNIST', len(mnist))
id = 4
img, label = mnist[id]
print(img)
print(label)

# On computer with monitor
# img.show()

img.save('work_dirs/tmp.jpg')
tensor = ToTensor()(img)
print(tensor.shape)
print(tensor.max())
print(tensor.min())

if __name__ == '__main__':
download_dataset()

执行这段代码,输出大致为:

1
2
3
4
5
6
length of MNIST 60000
<PIL.Image.Image image mode=L size=28x28 at 0x7FB3F09CCE50>
9
torch.Size([1, 28, 28])
tensor(1.)
tensor(0.)

第一行输出表明,MNIST数据集里有60000张图片。而从第二行和第三行输出中,我们发现每一项数据由图片和标签组成,图片是大小为28x28的PIL格式的图片,标签表明该图片是哪个数字。我们可以用torchvision里的ToTensor()把PIL图片转成PyTorch张量,进一步查看图片的信息。最后三行输出表明,每一张图片都是单通道图片(灰度图),颜色值的取值范围是0~1。

我们可以查看一下每张图片的样子。如果你是在用带显示器的电脑,可以去掉img.show那一行的注释,直接查看图片;如果你是在用服务器,可以去img.save的路径里查看图片。该图片的应该长这个样子:

我们可以用下面的代码预处理数据并创建DataLoader。由于DDPM会把图像和正态分布关联起来,我们更希望图像颜色值的取值范围是[-1, 1]。为此,我们可以对图像做一个线性变换,减0.5再乘2。

1
2
3
4
5
6
def get_dataloader(batch_size: int):
transform = Compose([ToTensor(), Lambda(lambda x: (x - 0.5) * 2)])
dataset = torchvision.datasets.MNIST(root='data/mnist',
transform=transform)
return DataLoader(dataset, batch_size=batch_size, shuffle=True)

DDPM 类

在代码中,我们要实现一个DDPM类。它维护了扩散过程中的一些常量(比如$\alpha$),并且可以计算正向过程和反向过程的结果。

先来实现一下DDPM类的初始化函数。一开始,我们遵从论文的配置,用torch.linspace(min_beta, max_beta, n_steps)min_betamax_beta线性地生成n_steps个时刻的$\beta$。接着,我们根据公式$\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i$,计算每个时刻的alphaalpha_bar。注意,为了方便实现,我们让t的取值从0开始,要比论文里的$t$少1。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch

class DDPM():

# n_steps 就是论文里的 T
def __init__(self,
device,
n_steps: int,
min_beta: float = 0.0001,
max_beta: float = 0.02):
betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
alphas = 1 - betas
alpha_bars = torch.empty_like(alphas)
product = 1
for i, alpha in enumerate(alphas):
product *= alpha
alpha_bars[i] = product
self.betas = betas
self.n_steps = n_steps
self.alphas = alphas
self.alpha_bars = alpha_bars

部分实现会让 DDPM 继承torch.nn.Module,但我认为这样不好。DDPM本身不是一个神经网络,它只是描述了前向过程和后向过程的一些计算。只有涉及可学习参数的神经网络类才应该继承 torch.nn.Module

准备好了变量后,我们可以来实现DDPM类的其他方法。先实现正向过程方法,该方法会根据公式$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon_t$计算正向过程中的$\mathbf{x}_t$。

1
2
3
4
5
6
def sample_forward(self, x, t, eps=None):
alpha_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)
if eps is None:
eps = torch.randn_like(x)
res = eps * torch.sqrt(1 - alpha_bar) + torch.sqrt(alpha_bar) * x
return res

这里要解释一些PyTorch编程上的细节。这份代码中,self.alpha_bars是一个一维Tensor。而在并行训练中,我们一般会令t为一个形状为(batch_size, )Tensor。PyTorch允许我们直接用self.alpha_bars[t]self.alpha_bars里取出batch_size个数,就像用一个普通的整型索引来从数组中取出一个数一样。有些实现会用torch.gatherself.alpha_bars里取数,其作用是一样的。

我们可以随机从训练集取图片做测试,看看它们在前向过程中是怎么逐步变成噪声的。

接下来实现反向过程。在反向过程中,DDPM会用神经网络预测每一轮去噪的均值,把$\mathbf{x}_t$复原回$\mathbf{x}_0$,以完成图像生成。反向过程即对应论文中的采样算法。

其实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def sample_backward(self, img_shape, net, device, simple_var=True):
x = torch.randn(img_shape).to(device)
net = net.to(device)
for t in range(self.n_steps - 1, -1, -1):
x = self.sample_backward_step(x, t, net, simple_var)
return x

def sample_backward_step(self, x_t, t, net, simple_var=True):
n = x_t.shape[0]
t_tensor = torch.tensor([t] * n,
dtype=torch.long).to(x_t.device).unsqueeze(1)
eps = net(x_t, t_tensor)

if t == 0:
noise = 0
else:
if simple_var:
var = self.betas[t]
else:
var = (1 - self.alpha_bars[t - 1]) / (
1 - self.alpha_bars[t]) * self.betas[t]
noise = torch.randn_like(x_t)
noise *= torch.sqrt(var)

mean = (x_t -
(1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *
eps) / torch.sqrt(self.alphas[t])
x_t = mean + noise

return x_t

其中,sample_backward是用来给外部调用的方法,而sample_backward_step是执行一步反向过程的方法。

sample_backward会随机生成纯噪声x(对应$\mathbf{x}_T$),再令tn_steps - 10,调用sample_backward_step

1
2
3
4
5
6
def sample_backward(self, img_shape, net, device, simple_var=True):
x = torch.randn(img_shape).to(device)
net = net.to(device)
for t in range(self.n_steps - 1, -1, -1):
x = self.sample_backward_step(x, t, net, simple_var)
return x

sample_backward_step中,我们先准备好这一步的神经网络输出eps。为此,我们要把整型的t转换成一个格式正确的Tensor。考虑到输入里可能有多个batch,我们先获取batch size n,再根据它来生成t_tensor

1
2
3
4
5
6
def sample_backward_step(self, x_t, t, net, simple_var=True):

n = x_t.shape[0]
t_tensor = torch.tensor([t] * n,
dtype=torch.long).to(x_t.device).unsqueeze(1)
eps = net(x_t, t_tensor)

之后,我们来处理反向过程公式中的方差项。根据伪代码,我们仅在t非零的时候算方差项。方差项用到的方差有两种取值,效果差不多,我们用simple_var来控制选哪种取值方式。获取方差后,我们再随机采样一个噪声,根据公式,得到方差项。

1
2
3
4
5
6
7
8
9
10
if t == 0:
noise = 0
else:
if simple_var:
var = self.betas[t]
else:
var = (1 - self.alpha_bars[t - 1]) / (
1 - self.alpha_bars[t]) * self.betas[t]
noise = torch.randn_like(x_t)
noise *= torch.sqrt(var)

最后,我们把eps和方差项套入公式,得到这一步更新过后的图像x_t

1
2
3
4
5
6
mean = (x_t -
(1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *
eps) / torch.sqrt(self.alphas[t])
x_t = mean + noise

return x_t

稍后完成了训练后,我们再来看反向过程的输出结果。

训练算法

接下来,我们先跳过神经网络的实现,直接完成论文里的训练算法。

再回顾一遍伪代码。首先,我们要随机选取训练图片$\mathbf{x}_{0}$,随机生成当前要训练的时刻$t$,以及随机生成一个生成$\mathbf{x}_{t}$的高斯噪声。之后,我们把$\mathbf{x}_{t}$和$t$输入进神经网络,尝试预测噪声。最后,我们以预测噪声和实际噪声的均方误差为损失函数做梯度下降。

为此,我们可以用下面的代码实现训练。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
from dldemos.ddpm.dataset import get_dataloader, get_img_shape
from dldemos.ddpm.ddpm import DDPM
import cv2
import numpy as np
import einops

batch_size = 512
n_epochs = 100


def train(ddpm: DDPM, net, device, ckpt_path):
# n_steps 就是公式里的 T
# net 是某个继承自 torch.nn.Module 的神经网络
n_steps = ddpm.n_steps
dataloader = get_dataloader(batch_size)
net = net.to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), 1e-3)

for e in range(n_epochs):
for x, _ in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)
t = torch.randint(0, n_steps, (current_batch_size, )).to(device)
eps = torch.randn_like(x).to(device)
x_t = ddpm.sample_forward(x, t, eps)
eps_theta = net(x_t, t.reshape(current_batch_size, 1))
loss = loss_fn(eps_theta, eps)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(net.state_dict(), ckpt_path)

代码的主要逻辑都在循环里。首先是完成训练数据$\mathbf{x}_{0}$、$t$、噪声的采样。采样$\mathbf{x}_{0}$的工作可以交给PyTorch的DataLoader完成,每轮遍历得到的x就是训练数据。$t$的采样可以用torch.randint函数随机从[0, n_steps - 1]取数。采样高斯噪声可以直接用torch.randn_like(x)生成一个和训练图片x形状一样的符合标准正态分布的图像。

1
2
3
4
5
for x, _ in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)
t = torch.randint(0, n_steps, (current_batch_size, )).to(device)
eps = torch.randn_like(x).to(device)

之后计算$\mathbf{x}_{t}$并将其和$t$输入进神经网络net。计算$\mathbf{x}_{t}$的任务会由DDPM类的sample_forward方法完成,我们在上文已经实现了它。

1
2
x_t = ddpm.sample_forward(x, t, eps)
eps_theta = net(x_t, t.reshape(current_batch_size, 1))

得到了预测的噪声eps_theta,我们调用PyTorch的API,算均方误差并调用优化器即可。

1
2
3
4
5
6
7
8
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), 1e-3)

...
loss = loss_fn(eps_theta, eps)
optimizer.zero_grad()
loss.backward()
optimizer.step()

去噪神经网络

在DDPM中,理论上我们可以用任意一种神经网络架构。但由于DDPM任务十分接近图像去噪任务,而U-Net又是去噪任务中最常见的网络架构,因此绝大多数DDPM都会使用基于U-Net的神经网络。

我一直想训练一个尽可能简单的模型。经过多次实验,我发现DDPM的神经网络很难训练。哪怕是对于比较简单的MNIST数据集,结构差一点的网络(比如纯ResNet)都不太行,只有带了残差块和时序编码的U-Net才能较好地完成去噪。注意力模块倒是可以不用加上。

由于神经网络结构并不是DDPM学习的重点,我这里就不对U-Net的写法做解说,而是直接贴上代码了。代码中大部分内容都和普通的U-Net无异。唯一要注意的地方就是时序编码。去噪网络的输入除了图像外,还有一个时间戳t。我们要考虑怎么把t的信息和输入图像信息融合起来。大部分人的做法是对t进行Transformer中的位置编码,把该编码加到图像的每一处上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import torch
import torch.nn as nn
import torch.nn.functional as F
from dldemos.ddpm.dataset import get_img_shape


class PositionalEncoding(nn.Module):

def __init__(self, max_seq_len: int, d_model: int):
super().__init__()

# Assume d_model is an even number for convenience
assert d_model % 2 == 0

pe = torch.zeros(max_seq_len, d_model)
i_seq = torch.linspace(0, max_seq_len - 1, max_seq_len)
j_seq = torch.linspace(0, d_model - 2, d_model // 2)
pos, two_i = torch.meshgrid(i_seq, j_seq)
pe_2i = torch.sin(pos / 10000**(two_i / d_model))
pe_2i_1 = torch.cos(pos / 10000**(two_i / d_model))
pe = torch.stack((pe_2i, pe_2i_1), 2).reshape(max_seq_len, d_model)

self.embedding = nn.Embedding(max_seq_len, d_model)
self.embedding.weight.data = pe
self.embedding.requires_grad_(False)

def forward(self, t):
return self.embedding(t)


class ResidualBlock(nn.Module):

def __init__(self, in_c: int, out_c: int):
super().__init__()
self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(out_c)
self.actvation1 = nn.ReLU()
self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)
self.bn2 = nn.BatchNorm2d(out_c)
self.actvation2 = nn.ReLU()
if in_c != out_c:
self.shortcut = nn.Sequential(nn.Conv2d(in_c, out_c, 1),
nn.BatchNorm2d(out_c))
else:
self.shortcut = nn.Identity()

def forward(self, input):
x = self.conv1(input)
x = self.bn1(x)
x = self.actvation1(x)
x = self.conv2(x)
x = self.bn2(x)
x += self.shortcut(input)
x = self.actvation2(x)
return x


class ConvNet(nn.Module):

def __init__(self,
n_steps,
intermediate_channels=[10, 20, 40],
pe_dim=10,
insert_t_to_all_layers=False):
super().__init__()
C, H, W = get_img_shape() # 1, 28, 28
self.pe = PositionalEncoding(n_steps, pe_dim)

self.pe_linears = nn.ModuleList()
self.all_t = insert_t_to_all_layers
if not insert_t_to_all_layers:
self.pe_linears.append(nn.Linear(pe_dim, C))

self.residual_blocks = nn.ModuleList()
prev_channel = C
for channel in intermediate_channels:
self.residual_blocks.append(ResidualBlock(prev_channel, channel))
if insert_t_to_all_layers:
self.pe_linears.append(nn.Linear(pe_dim, prev_channel))
else:
self.pe_linears.append(None)
prev_channel = channel
self.output_layer = nn.Conv2d(prev_channel, C, 3, 1, 1)

def forward(self, x, t):
n = t.shape[0]
t = self.pe(t)
for m_x, m_t in zip(self.residual_blocks, self.pe_linears):
if m_t is not None:
pe = m_t(t).reshape(n, -1, 1, 1)
x = x + pe
x = m_x(x)
x = self.output_layer(x)
return x


class UnetBlock(nn.Module):

def __init__(self, shape, in_c, out_c, residual=False):
super().__init__()
self.ln = nn.LayerNorm(shape)
self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)
self.activation = nn.ReLU()
self.residual = residual
if residual:
if in_c == out_c:
self.residual_conv = nn.Identity()
else:
self.residual_conv = nn.Conv2d(in_c, out_c, 1)

def forward(self, x):
out = self.ln(x)
out = self.conv1(out)
out = self.activation(out)
out = self.conv2(out)
if self.residual:
out += self.residual_conv(x)
out = self.activation(out)
return out


class UNet(nn.Module):

def __init__(self,
n_steps,
channels=[10, 20, 40, 80],
pe_dim=10,
residual=False) -> None:
super().__init__()
C, H, W = get_img_shape()
layers = len(channels)
Hs = [H]
Ws = [W]
cH = H
cW = W
for _ in range(layers - 1):
cH //= 2
cW //= 2
Hs.append(cH)
Ws.append(cW)

self.pe = PositionalEncoding(n_steps, pe_dim)

self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
self.pe_linears_en = nn.ModuleList()
self.pe_linears_de = nn.ModuleList()
self.downs = nn.ModuleList()
self.ups = nn.ModuleList()
prev_channel = C
for channel, cH, cW in zip(channels[0:-1], Hs[0:-1], Ws[0:-1]):
self.pe_linears_en.append(
nn.Sequential(nn.Linear(pe_dim, prev_channel), nn.ReLU(),
nn.Linear(prev_channel, prev_channel)))
self.encoders.append(
nn.Sequential(
UnetBlock((prev_channel, cH, cW),
prev_channel,
channel,
residual=residual),
UnetBlock((channel, cH, cW),
channel,
channel,
residual=residual)))
self.downs.append(nn.Conv2d(channel, channel, 2, 2))
prev_channel = channel

self.pe_mid = nn.Linear(pe_dim, prev_channel)
channel = channels[-1]
self.mid = nn.Sequential(
UnetBlock((prev_channel, Hs[-1], Ws[-1]),
prev_channel,
channel,
residual=residual),
UnetBlock((channel, Hs[-1], Ws[-1]),
channel,
channel,
residual=residual),
)
prev_channel = channel
for channel, cH, cW in zip(channels[-2::-1], Hs[-2::-1], Ws[-2::-1]):
self.pe_linears_de.append(nn.Linear(pe_dim, prev_channel))
self.ups.append(nn.ConvTranspose2d(prev_channel, channel, 2, 2))
self.decoders.append(
nn.Sequential(
UnetBlock((channel * 2, cH, cW),
channel * 2,
channel,
residual=residual),
UnetBlock((channel, cH, cW),
channel,
channel,
residual=residual)))

prev_channel = channel

self.conv_out = nn.Conv2d(prev_channel, C, 3, 1, 1)

def forward(self, x, t):
n = t.shape[0]
t = self.pe(t)
encoder_outs = []
for pe_linear, encoder, down in zip(self.pe_linears_en, self.encoders,
self.downs):
pe = pe_linear(t).reshape(n, -1, 1, 1)
x = encoder(x + pe)
encoder_outs.append(x)
x = down(x)
pe = self.pe_mid(t).reshape(n, -1, 1, 1)
x = self.mid(x + pe)
for pe_linear, decoder, up, encoder_out in zip(self.pe_linears_de,
self.decoders, self.ups,
encoder_outs[::-1]):
pe = pe_linear(t).reshape(n, -1, 1, 1)
x = up(x)

pad_x = encoder_out.shape[2] - x.shape[2]
pad_y = encoder_out.shape[3] - x.shape[3]
x = F.pad(x, (pad_x // 2, pad_x - pad_x // 2, pad_y // 2,
pad_y - pad_y // 2))
x = torch.cat((encoder_out, x), dim=1)
x = decoder(x + pe)
x = self.conv_out(x)
return x


convnet_small_cfg = {
'type': 'ConvNet',
'intermediate_channels': [10, 20],
'pe_dim': 128
}

convnet_medium_cfg = {
'type': 'ConvNet',
'intermediate_channels': [10, 10, 20, 20, 40, 40, 80, 80],
'pe_dim': 256,
'insert_t_to_all_layers': True
}
convnet_big_cfg = {
'type': 'ConvNet',
'intermediate_channels': [20, 20, 40, 40, 80, 80, 160, 160],
'pe_dim': 256,
'insert_t_to_all_layers': True
}

unet_1_cfg = {'type': 'UNet', 'channels': [10, 20, 40, 80], 'pe_dim': 128}
unet_res_cfg = {
'type': 'UNet',
'channels': [10, 20, 40, 80],
'pe_dim': 128,
'residual': True
}


def build_network(config: dict, n_steps):
network_type = config.pop('type')
if network_type == 'ConvNet':
network_cls = ConvNet
elif network_type == 'UNet':
network_cls = UNet

network = network_cls(n_steps, **config)
return network

实验结果与采样

把之前的所有代码综合一下,我们以带残差块的U-Net为去噪网络,执行训练。

1
2
3
4
5
6
7
8
9
10
11
if __name__ == '__main__':
n_steps = 1000
config_id = 4
device = 'cuda'
model_path = 'dldemos/ddpm/model_unet_res.pth'

config = unet_res_cfg
net = build_network(config, n_steps)
ddpm = DDPM(device, n_steps)

train(ddpm, net, device=device, ckpt_path=model_path)

按照默认训练配置,在3090上花5分钟不到,训练30~40个epoch即可让网络基本收敛。最终收敛时loss在0.023~0.024左右。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
batch size: 512
epoch 0 loss: 0.23103461712201437 elapsed 7.01s
epoch 1 loss: 0.0627968365987142 elapsed 13.66s
epoch 2 loss: 0.04828845852613449 elapsed 20.25s
epoch 3 loss: 0.04148937337398529 elapsed 26.80s
epoch 4 loss: 0.03801360730528831 elapsed 33.37s
epoch 5 loss: 0.03604260584712028 elapsed 39.96s
epoch 6 loss: 0.03357676289876302 elapsed 46.57s
epoch 7 loss: 0.0335664684087038 elapsed 53.15s
...
epoch 30 loss: 0.026149748386939366 elapsed 204.64s
epoch 31 loss: 0.025854381563266117 elapsed 211.24s
epoch 32 loss: 0.02589433005253474 elapsed 217.84s
epoch 33 loss: 0.026276464049021404 elapsed 224.41s
...
epoch 96 loss: 0.023299352884292603 elapsed 640.25s
epoch 97 loss: 0.023460942271351815 elapsed 646.90s
epoch 98 loss: 0.023584651704629263 elapsed 653.54s
epoch 99 loss: 0.02364126600921154 elapsed 660.22s

训练这个网络时,并没有特别好的测试指标,我们只能通过观察采样图像来评价网络的表现。我们可以用下面的代码调用DDPM的反向传播方法,生成多幅图像并保存下来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def sample_imgs(ddpm,
net,
output_path,
n_sample=81,
device='cuda',
simple_var=True):
net = net.to(device)
net = net.eval()
with torch.no_grad():
shape = (n_sample, *get_img_shape()) # 1, 3, 28, 28
imgs = ddpm.sample_backward(shape,
net,
device=device,
simple_var=simple_var).detach().cpu()
imgs = (imgs + 1) / 2 * 255
imgs = imgs.clamp(0, 255)
imgs = einops.rearrange(imgs,
'(b1 b2) c h w -> (b1 h) (b2 w) c',
b1=int(n_sample**0.5))

imgs = imgs.numpy().astype(np.uint8)

cv2.imwrite(output_path, imgs)

一切顺利的话,我们可以得到一些不错的生成结果。下图是我得到的一些生成图片:

大部分生成的图片都对应一个阿拉伯数字,它们和训练集MNIST里的图片非常接近。这算是一个不错的生成结果。

如果神经网络的拟合能力较弱,生成结果就会差很多。下图是我训练一个简单的ResNet后得到的采样结果:

可以看出,每幅图片都很乱,基本对应不上一个数字。这就是一个较差的训练结果。

如果网络再差一点,可能会生成纯黑或者纯白的图片。这是因为网络的预测结果不准,在反向过程中,图像的均值不断偏移,偏移到远大于1或者远小于-1的值了。

总结一下,在复现DDPM时,最主要是要学习DDPM论文的两个算法,即训练算法和采样算法。两个算法很简单,可以轻松地把它们翻译成代码。而为了成功完成复现,还需要花一点心思在编写U-Net上,尤其是注意处理时间戳的部分。

前段时间我写了一篇VQVAE的解读,现在再补充一篇VQVAE的PyTorch实现教程。在这个项目中,我们会实现VQVAE论文,在MNIST和CelebAHQ两个数据集上完成图像生成。具体来说,我们会先实现并训练一个图像压缩网络VQVAE,它能把真实图像编码成压缩图像,或者把压缩图像解码回真实图像。之后,我们会训练一个生成压缩图像的生成网络PixelCNN。

代码仓库:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/VQVAE

项目运行示例

如果你只是想快速地把项目运行起来,可以只阅读本节。

在本地安装好项目后,运行python dldemos/VQVAE/dataset.py来下载MNIST数据集。之后运行python dldemos/VQVAE/main.py,这个脚本会完成以下四个任务:

  1. 训练VQVAE
  2. 用VQVAE重建数据集里的随机数据
  3. 训练PixelCNN
  4. 用PixelCNN+VQVAE随机生成图片

第二步得到的重建结果大致如下(每对图片中左图是原图,右图是重建结果):

第四步得到的随机生成结果大致如下:

如果你要使用CelebAHQ数据集,请照着下一节的指示把CelebAHQ下载到指定目录,再执行python dldemos/VQVAE/main.py -c 4

数据集准备

MNIST数据集可以用PyTorch的API自动下载。我们可以用下面的代码下载MNIST数据集并查看数据的格式。从输出中可知,MNIST的图片形状为[1, 28, 28],颜色取值范围为[0, 1]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def download_mnist():
mnist = torchvision.datasets.MNIST(root='data/mnist', download=True)
print('length of MNIST', len(mnist))
id = 4
img, label = mnist[id]
print(img)
print(label)

# On computer with monitor
# img.show()

img.save('work_dirs/tmp_mnist.jpg')
tensor = transforms.ToTensor()(img)
print(tensor.shape)
print(tensor.max())
print(tensor.min())

我们可以用下面的代码把它封成简单的Dataset

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class MNISTImageDataset(Dataset):

def __init__(self, img_shape=(28, 28)):
super().__init__()
self.img_shape = img_shape
self.mnist = torchvision.datasets.MNIST(root='data/mnist')

def __len__(self):
return len(self.mnist)

def __getitem__(self, index: int):
img = self.mnist[index][0]
pipeline = transforms.Compose(
[transforms.Resize(self.img_shape),
transforms.ToTensor()])
return pipeline(img)

接下来准备CelebAHQ。CelebAHQ数据集原本的图像大小是1024x1024,但我们这个项目用不到这么大的图片。我在kaggle上找到了一个256x256的CelebAHQ (https://www.kaggle.com/datasets/badasstechie/celebahq-resized-256x256),所有文件加起来只有300MB左右,很适合我们项目。请在该页面下载压缩包,并把压缩包解压到项目的`data/celebA/celeba_hq_256`目录下。

下载完数据后,我们可以写一个简单的从目录中读取图片的Dataset类。和MNIST的预处理流程不同,我这里给CelebAHQ的图片加了一个中心裁剪的操作,一来可以让人脸占比更大,便于模型学习,二来可以让该类兼容CelebA数据集(CelebA数据集的图片不是正方形,需要裁剪)。这个操作是可选的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class CelebADataset(Dataset):

def __init__(self, root, img_shape=(64, 64)):
super().__init__()
self.root = root
self.img_shape = img_shape
self.filenames = sorted(os.listdir(root))

def __len__(self) -> int:
return len(self.filenames)

def __getitem__(self, index: int):
path = os.path.join(self.root, self.filenames[index])
img = Image.open(path)
pipeline = transforms.Compose([
transforms.CenterCrop(168),
transforms.Resize(self.img_shape),
transforms.ToTensor()
])
return pipeline(img)

有了数据集类后,我们可以用它们生成Dataloader

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
CELEBA_DIR = 'data/celebA/img_align_celeba'
CELEBA_HQ_DIR = 'data/celebA/celeba_hq_256'
def get_dataloader(type,
batch_size,
img_shape=None,
dist_train=False,
num_workers=4,
**kwargs):
if type == 'CelebA':
if img_shape is not None:
kwargs['img_shape'] = img_shape
dataset = CelebADataset(CELEBA_DIR, **kwargs)
elif type == 'CelebAHQ':
if img_shape is not None:
kwargs['img_shape'] = img_shape
dataset = CelebADataset(CELEBA_HQ_DIR, **kwargs)
elif type == 'MNIST':
if img_shape is not None:
dataset = MNISTImageDataset(img_shape)
else:
dataset = MNISTImageDataset()
if dist_train:
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers)
return dataloader, sampler
else:
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers)
return dataloader

我们可以利用Dataloader来查看CelebAHQ数据集的内容及数据格式。

1
2
3
4
5
6
7
8
9
10
11
12
13
if os.path.exists(CELEBA_HQ_DIR):
dataloader = get_dataloader('CelebAHQ', 16)
img = next(iter(dataloader))
print(img.shape)
N = img.shape[0]
img = einops.rearrange(img,
'(n1 n2) c h w -> c (n1 h) (n2 w)',
n1=int(N**0.5))
print(img.shape)
print(img.max())
print(img.min())
img = transforms.ToPILImage()(img)
img.save('work_dirs/tmp_celebahq.jpg')

从输出中可知,CelebAHQ的颜色取值范围同样是[0, 1]。经我们的预处理流水线得到的图片如下。

实现并训练 VQVAE

要用VQVAE做图像生成,其实要训练两个模型:一个是用于压缩图像的VQVAE,另一个是生成压缩图像的PixelCNN。这两个模型是可以分开训练的。我们先来实现并训练VQVAE。

VQVAE的架构非常简单:一个编码器,一个解码器,外加中间一个嵌入层。损失函数为图像的重建误差与编码器输出与其对应嵌入之间的误差。

VQVAE的编码器和解码器的结构也很简单,仅由普通的上/下采样层和残差块组成。具体来说,编码器先是有两个3x3卷积+2倍下采样卷积的模块,再有两个残差块(ReLU, 3x3卷积, ReLU, 1x1卷积);解码器则反过来,先有两个残差块,再有两个3x3卷积+2倍上采样反卷积的模块。为了让代码看起来更清楚一点,我们不用过度封装,仅实现一个残差块模块,再用残差块和PyTorch自带模块拼成VQVAE。

先实现残差块。注意,由于模型比较简单,残差块内部和VQVAE其他地方都可以不使用BatchNorm。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class ResidualBlock(nn.Module):

def __init__(self, dim):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)
self.conv2 = nn.Conv2d(dim, dim, 1)

def forward(self, x):
tmp = self.relu(x)
tmp = self.conv1(tmp)
tmp = self.relu(tmp)
tmp = self.conv2(tmp)
return x + tmp

有了残差块类后,我们可以直接实现VQVAE类。我们先在初始化函数里把模块按顺序搭好。编码器和解码器的结构按前文的描述搭起来即可。嵌入空间(codebook)其实就是个普通的嵌入层。此处我仿照他人代码给嵌入层显式初始化参数,但实测下来和默认的初始化参数方式差别不大。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class VQVAE(nn.Module):

def __init__(self, input_dim, dim, n_embedding):
super().__init__()
self.encoder = nn.Sequential(nn.Conv2d(input_dim, dim, 4, 2, 1),
nn.ReLU(), nn.Conv2d(dim, dim, 4, 2, 1),
nn.ReLU(), nn.Conv2d(dim, dim, 3, 1, 1),
ResidualBlock(dim), ResidualBlock(dim))
self.vq_embedding = nn.Embedding(n_embedding, dim)
self.vq_embedding.weight.data.uniform_(-1.0 / n_embedding,
1.0 / n_embedding)
self.decoder = nn.Sequential(
nn.Conv2d(dim, dim, 3, 1, 1),
ResidualBlock(dim), ResidualBlock(dim),
nn.ConvTranspose2d(dim, dim, 4, 2, 1), nn.ReLU(),
nn.ConvTranspose2d(dim, input_dim, 4, 2, 1))
self.n_downsample = 2

之后,我们来实现模型的前向传播。这里的逻辑就略显复杂了。整体来看,这个函数完成了编码、取最近邻、解码这三步。其中,取最近邻的部分最为复杂。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def forward(self, x):
# encode
ze = self.encoder(x)

# ze: [N, C, H, W]
# embedding [K, C]
embedding = self.vq_embedding.weight.data
N, C, H, W = ze.shape
K, _ = embedding.shape
embedding_broadcast = embedding.reshape(1, K, C, 1, 1)
ze_broadcast = ze.reshape(N, 1, C, H, W)
distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)
nearest_neighbor = torch.argmin(distance, 1)
# make C to the second dim
zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)
# stop gradient
decoder_input = ze + (zq - ze).detach()

# decode
x_hat = self.decoder(decoder_input)
return x_hat, ze, zq

我们来详细看一看取最近邻的实现。取最近邻时,我们要用到两块数据:编码器输出ze与嵌入矩阵embeddingze可以看成一个形状为[N, H, W]的数组,数组存储了长度为C的向量。而嵌入矩阵里有K个长度为C的向量。
1
2
3
4
5
# ze: [N, C, H, W]
# embedding [K, C]
embedding = self.vq_embedding.weight.data
N, C, H, W = ze.shape
K, _ = embedding.shape

为了求N*H*W个向量在嵌入矩阵里的最近邻,我们要先算这每个向量与嵌入矩阵里K个向量的距离。在算距离前,我们要把embeddingze的形状变换一下,保证(embedding_broadcast - ze_broadcast)**2的形状为[N, K, C, H, W]。我们对这个临时结果的第2号维度(C所在维度)求和,得到形状为[N, K, H, W]distance。它的含义是,对于N*H*W个向量,每个向量到嵌入空间里K个向量的距离分别是多少。
1
2
3
embedding_broadcast = embedding.reshape(1, K, C, 1, 1)
ze_broadcast = ze.reshape(N, 1, C, H, W)
distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)

有了距离张量后,我们再对其1号维度(K所在维度)求最近邻所在下标。

1
nearest_neighbor = torch.argmin(distance, 1)

有了下标后,我们可以用self.vq_embedding(nearest_neighbor)从嵌入空间取出最近邻了。别忘了,nearest_neighbor的形状是[N, H, W]self.vq_embedding(nearest_neighbor)的形状会是[N, H, W, C]。我们还要把C维度转置一下。

1
2
# make C to the second dim
zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)

最后,我们用论文里提到的停止梯度算子,把zq变形一下。这样,算误差的时候用的是zq,算梯度时ze会接收解码器传来的梯度。

1
2
# stop gradient
decoder_input = ze + (zq - ze).detach()

求最近邻的部分就到此结束了。最后再补充一句,前向传播函数不仅返回了重建结果x_hat,还返回了ze, zq。这是因为我们待会要在训练时根据ze, zq求损失函数。

准备好了模型类后,假设我们已经用某些超参数初始化好了模型model,我们可以用下面的代码训练VQVAE。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def train_vqvae(model: VQVAE,
img_shape=None,
device='cuda',
ckpt_path='dldemos/VQVAE/model.pth',
batch_size=64,
dataset_type='MNIST',
lr=1e-3,
n_epochs=100,
l_w_embedding=1,
l_w_commitment=0.25):
print('batch size:', batch_size)
dataloader = get_dataloader(dataset_type,
batch_size,
img_shape=img_shape,
use_lmdb=USE_LMDB)
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr)
mse_loss = nn.MSELoss()
tic = time.time()
for e in range(n_epochs):
total_loss = 0

for x in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)

x_hat, ze, zq = model(x)
l_reconstruct = mse_loss(x, x_hat)
l_embedding = mse_loss(ze.detach(), zq)
l_commitment = mse_loss(ze, zq.detach())
loss = l_reconstruct + \
l_w_embedding * l_embedding + l_w_commitment * l_commitment
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * current_batch_size
total_loss /= len(dataloader.dataset)
toc = time.time()
torch.save(model.state_dict(), ckpt_path)
print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
print('Done')

先看一下训练函数的参数。其他参数都没什么特别的,只有误差权重l_w_embedding=1,l_w_commitment=0.25需要讨论一下。误差函数有三项,但论文只给了第三项的权重(0.25),默认第二项的权重为1。我在实现时把第二项的权重l_w_embedding也加上了。

1
2
3
4
5
6
7
8
9
10
def train_vqvae(model: VQVAE,
img_shape=None,
device='cuda',
ckpt_path='dldemos/VQVAE/model.pth',
batch_size=64,
dataset_type='MNIST',
lr=1e-3,
n_epochs=100,
l_w_embedding=1,
l_w_commitment=0.25):

再来把函数体过一遍。一开始,我们可以用传来的参数把dataloader初始化一下。

1
2
3
4
5
print('batch size:', batch_size)
dataloader = get_dataloader(dataset_type,
batch_size,
img_shape=img_shape,
use_lmdb=USE_LMDB)

再把模型的状态调好,并准备好优化器和算均方误差的函数。

1
2
3
4
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr)
mse_loss = nn.MSELoss()

准备好变量后,进入训练循环。训练的过程比较常规,唯一要注意的就是误差计算部分。由于我们把复杂的逻辑都放在了模型类中,这里我们可以直接先用model(x)得到重建图像x_hat和算误差的ze, zq,再根据论文里的公式算3个均方误差,最后求一个加权和,代码比较简明。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
for e in range(n_epochs):
for x in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)

x_hat, ze, zq = model(x)
l_reconstruct = mse_loss(x, x_hat)
l_embedding = mse_loss(ze.detach(), zq)
l_commitment = mse_loss(ze, zq.detach())
loss = l_reconstruct + \
l_w_embedding * l_embedding + l_w_commitment * l_commitment
optimizer.zero_grad()
loss.backward()
optimizer.step()

训练完毕后,我们可以用下面的代码来测试VQVAE的重建效果。所谓重建,就是模拟训练的过程,随机取一些图片,先编码后解码,看解码出来的图片和原图片是否一致。为了获取重建后的图片,我们只需要直接执行前向传播函数model(x)即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def reconstruct(model, x, device, dataset_type='MNIST'):
model.to(device)
model.eval()
with torch.no_grad():
x_hat, _, _ = model(x)
n = x.shape[0]
n1 = int(n**0.5)
x_cat = torch.concat((x, x_hat), 3)
x_cat = einops.rearrange(x_cat, '(n1 n2) c h w -> (n1 h) (n2 w) c', n1=n1)
x_cat = (x_cat.clip(0, 1) * 255).cpu().numpy().astype(np.uint8)
if dataset_type == 'CelebA' or dataset_type == 'CelebAHQ':
x_cat = cv2.cvtColor(x_cat, cv2.COLOR_RGB2BGR)
cv2.imwrite(f'work_dirs/vqvae_reconstruct_{dataset_type}.jpg', x_cat)

vqvae = ...
dataloader = get_dataloader(...)
img = next(iter(dataloader)).to(device)
reconstruct(vqvae, img, device, cfg['dataset_type'])

训练压缩图像生成模型 PixelCNN

有了一个VQVAE后,我们要用另一个模型对VQVAE的离散空间采样,也就是训练一个能生成压缩图片的模型。我们可以按照VQVAE论文的方法,使用PixelCNN来生成压缩图片。

PixelCNN 的原理及实现方法就不在这里过多介绍了。详情可以参见我之前的PixelCNN解读文章。简单来说,PixelCNN给每个像素从左到右,从上到下地编了一个序号,让每个像素仅由之前所有像素决定。采样时,PixelCNN按序号从左上到右下逐个生成图像的每一个像素;训练时,PixelCNN使用了某种掩码机制,使得每个像素只能看到编号更小的像素,并行地输出每一个像素的生成结果。

PixelCNN具体的训练示意图如下。模型的输入是一幅图片,每个像素的取值是0~255;模型给图片的每个像素输出了一个概率分布,即表示此处颜色取0,取1,……,取255的概率。由于神经网络假设数据的输入符合标准正态分布,我们要在数据输入前把整型的颜色转换成0~1之间的浮点数。最简单的转换方法是除以255。

以上是训练PixelCNN生成普通图片的过程。而在训练PixelCNN生成压缩图片时,上述过程需要修改。压缩图片的取值是离散编码。离散编码和颜色值不同,它不是连续的。你可以说颜色1和颜色0、2相近,但不能说离散编码1和离散编码0、2相近。因此,为了让PixelCNN建模离散编码,需要把原来的除以255操作换成一个嵌入层,使得网络能够读取离散编码。

反映在代码中,假设我们已经有了一个普通的PixelCNN模型GatedPixelCNN,我们需要在整个模型的最前面套一个嵌入层,嵌入层的嵌入个数等于离散编码的个数(color_level),嵌入长度等于模型的特征长度(p)。由于嵌入层会直接输出一个长度为p的向量,我们还需要把第一个模块的输入通道数改成p

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from dldemos.pixelcnn.model import GatedPixelCNN, GatedBlock

import torch.nn as nn


class PixelCNNWithEmbedding(GatedPixelCNN):

def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
super().__init__(n_blocks, p, linear_dim, bn, color_level)
self.embedding = nn.Embedding(color_level, p)
self.block1 = GatedBlock('A', p, p, bn)

def forward(self, x):
x = self.embedding(x)
x = x.permute(0, 3, 1, 2).contiguous()
return super().forward(x)

有了一个能处理离散编码的PixelCNN后,我们可以用下面的代码来训练PixelCNN。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def train_generative_model(vqvae: VQVAE,
model,
img_shape=None,
device='cuda',
ckpt_path='dldemos/VQVAE/gen_model.pth',
dataset_type='MNIST',
batch_size=64,
n_epochs=50):
print('batch size:', batch_size)
dataloader = get_dataloader(dataset_type,
batch_size,
img_shape=img_shape,
use_lmdb=USE_LMDB)
vqvae.to(device)
vqvae.eval()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
loss_fn = nn.CrossEntropyLoss()
tic = time.time()
for e in range(n_epochs):
total_loss = 0
for x in dataloader:
current_batch_size = x.shape[0]
with torch.no_grad():
x = x.to(device)
x = vqvae.encode(x)

predict_x = model(x)
loss = loss_fn(predict_x, x)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * current_batch_size
total_loss /= len(dataloader.dataset)
toc = time.time()
torch.save(model.state_dict(), ckpt_path)
print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
print('Done')
gen_model = PixelCNNWithEmbedding(cfg['pixelcnn_n_blocks'],
cfg['pixelcnn_dim'],
cfg['pixelcnn_linear_dim'], True,
cfg['n_embedding'])
vqvae.load_state_dict(torch.load(cfg['vqvae_path']))
train_generative_model(vqvae,
gen_model,
img_shape=(img_shape[1], img_shape[2]),
device=device,
ckpt_path=cfg['gen_model_path'],
dataset_type=cfg['dataset_type'],
batch_size=cfg['batch_size_2'],
n_epochs=cfg['n_epochs_2'])

训练部分的核心代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
loss_fn = nn.CrossEntropyLoss()
for x in dataloader:
current_batch_size = x.shape[0]
with torch.no_grad():
x = x.to(device)
x = vqvae.encode(x)

predict_x = model(x)
loss = loss_fn(predict_x, x)
optimizer.zero_grad()
loss.backward()
optimizer.step()

这段代码的意思是说,从训练集里随机取图片x,再将图片压缩成离散编码x = vqvae.encode(x)。这时,x既是PixelCNN的输入,也是PixelCNN的拟合目标。把它输入进PixelCNN,PixelCNN会输出每个像素的概率分布。用交叉熵损失函数约束输出结果即可。

训练完毕后,我们可以用下面的函数来完成整套图像生成流水线。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def sample_imgs(vqvae: VQVAE,
gen_model,
img_shape,
n_sample=81,
device='cuda',
dataset_type='MNIST'):
vqvae = vqvae.to(device)
vqvae.eval()
gen_model = gen_model.to(device)
gen_model.eval()

C, H, W = img_shape
H, W = vqvae.get_latent_HW((C, H, W))
input_shape = (n_sample, H, W)
x = torch.zeros(input_shape).to(device).to(torch.long)
with torch.no_grad():
for i in range(H):
for j in range(W):
output = gen_model(x)
prob_dist = F.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist, 1)
x[:, i, j] = pixel[:, 0]

imgs = vqvae.decode(x)

imgs = imgs * 255
imgs = imgs.clip(0, 255)
imgs = einops.rearrange(imgs,
'(n1 n2) c h w -> (n1 h) (n2 w) c',
n1=int(n_sample**0.5))

imgs = imgs.detach().cpu().numpy().astype(np.uint8)
if dataset_type == 'CelebA' or dataset_type == 'CelebAHQ':
imgs = cv2.cvtColor(imgs, cv2.COLOR_RGB2BGR)

cv2.imwrite(f'work_dirs/vqvae_sample_{dataset_type}.jpg', imgs)

抛掉前后处理,和图像生成有关的代码如下。一开始,我们要随便创建一个空图片x,用于储存PixelCNN生成的压缩图片。之后,我们按顺序遍历每个像素,把当前图片输入进PixelCNN,让PixelCNN预测下一个像素的概率分布prob_dist。我们再用torch.multinomial从概率分布中采样,把采样的结果填回图片。遍历结束后,我们用VQVAE的解码器把压缩图片变成真实图片。

1
2
3
4
5
6
7
8
9
10
11
12
13
C, H, W = img_shape
H, W = vqvae.get_latent_HW((C, H, W))
input_shape = (n_sample, H, W)
x = torch.zeros(input_shape).to(device).to(torch.long)
with torch.no_grad():
for i in range(H):
for j in range(W):
output = gen_model(x)
prob_dist = F.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist, 1)
x[:, i, j] = pixel[:, 0]

imgs = vqvae.decode(x)

至此,我们已经实现了用VQVAE做图像生成的四个任务:训练VQVAE、重建图像、训练PixelCNN、随机生成图像。完整的main函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
if __name__ == '__main__':
os.makedirs('work_dirs', exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument('-c', type=int, default=0)
parser.add_argument('-d', type=int, default=0)
args = parser.parse_args()
cfg = get_cfg(args.c)

device = f'cuda:{args.d}'

img_shape = cfg['img_shape']

vqvae = VQVAE(img_shape[0], cfg['dim'], cfg['n_embedding'])
gen_model = PixelCNNWithEmbedding(cfg['pixelcnn_n_blocks'],
cfg['pixelcnn_dim'],
cfg['pixelcnn_linear_dim'], True,
cfg['n_embedding'])
# 1. Train VQVAE
train_vqvae(vqvae,
img_shape=(img_shape[1], img_shape[2]),
device=device,
ckpt_path=cfg['vqvae_path'],
batch_size=cfg['batch_size'],
dataset_type=cfg['dataset_type'],
lr=cfg['lr'],
n_epochs=cfg['n_epochs'],
l_w_embedding=cfg['l_w_embedding'],
l_w_commitment=cfg['l_w_commitment'])

# 2. Test VQVAE by visualizaing reconstruction result
vqvae.load_state_dict(torch.load(cfg['vqvae_path']))
dataloader = get_dataloader(cfg['dataset_type'],
16,
img_shape=(img_shape[1], img_shape[2]))
img = next(iter(dataloader)).to(device)
reconstruct(vqvae, img, device, cfg['dataset_type'])

# 3. Train Generative model (Gated PixelCNN in our project)
vqvae.load_state_dict(torch.load(cfg['vqvae_path']))

train_generative_model(vqvae,
gen_model,
img_shape=(img_shape[1], img_shape[2]),
device=device,
ckpt_path=cfg['gen_model_path'],
dataset_type=cfg['dataset_type'],
batch_size=cfg['batch_size_2'],
n_epochs=cfg['n_epochs_2'])

# 4. Sample VQVAE
vqvae.load_state_dict(torch.load(cfg['vqvae_path']))
gen_model.load_state_dict(torch.load(cfg['gen_model_path']))
sample_imgs(vqvae,
gen_model,
cfg['img_shape'],
device=device,
dataset_type=cfg['dataset_type'])

实验

VQVAE有两个超参数:嵌入个数n_embedding、特征向量长度dim。论文中n_embedding=512dim=256。而经我实现发现,用更小的参数量也能达到不错的效果。

所有实验的配置文件我都放在了该项目目录下config.py文件中。对于MNIST数据集,我使用的模型超参数为:dim=32, n_embedding=32。VQVAE重建结果如下所示。可以说重建得几乎完美(每对图片左图为原图,右图为重建结果)。

而对于CelebAHQ数据集,我测试了不同输入尺寸下的不同VQVAE,共有4组配置。

  1. shape=(3, 128, 128) dim=128 n_embedding=64
  2. shape=(3, 128, 128) dim=128 n_embedding=128
  3. shape=(3, 64, 64) dim=128 n_embedding=64
  4. shape=(3, 64, 64) dim=128 n_embedding=32

实验的结果很好预测。对于同尺寸的图片,嵌入数越多重建效果越好。这里我只展示下第一组和第二组的重建结果。

可以看出,VQVAE的重建效果还不错。但由于只使用了均方误差,重建图片在细节上还是比较模糊。重建效果还是很重要的,它决定了该方法做图像生成的质量上限。后续有很多工作都试图提升VQVAE的重建效果。

接下来来看一下随机图像生成的实验。PixelCNN主要有模块数n_blocks、特征长度dim,输出线性层特征长度linear_dim这三个超参数。其中模块数一般是固定的,而输出线性层就被用了一次,其特征长度的影响不大。最需要调节的是特征长度dim。对于MNIST,我的超参数设置为

  • n_blocks=15 dim=128 linear_dim=32.

对于CelebAHQ,我的超参数设置为

  • n_blocks=15 dim=384 linear_dim=256.

PixelCNN的训练时间主要由输入图片尺寸和dim决定,训练难度主要由VQVAE的嵌入个数(即多分类的类别数)决定。PixelCNN训起来很花时间。如果时间有限,在CelebAHQ上建议只训练最小最简单的第4组配置。我在项目中提供了PixelCNN的并行训练脚本,比如用下面的命令可以用4张卡在1号配置下并行训练。

1
torchrun --nproc_per_node=4 dldemos/VQVAE/dist_train_pixelcnn.py -c 1

来看一下实验结果。MNIST上的采样结果还是非常不错的。

CelebAHQ上的结果会差一点。以下是第4组配置(图像边长64,嵌入数32)的采样结果。大部分图片都还行,起码看得出是一张人脸。但64x64的图片本来就分辨率不高,加上VQVAE解码的损耗,放大来看人脸还是比较模糊的。

第1组配置(图像边长128,嵌入数64)的PixelCNN实在训练得太慢了,我只训了一个半成品模型。由于部分生成结果比较吓人,我只挑了几个还能看得过去的生成结果。可以看出,如果把模型训完的话,边长128的模型肯定比边长64的模型效果更好。

参考资料

网上几乎找不到在CelebAHQ上训练的VQVAE PyTorch项目。我在实现这份代码时,参考了以下项目:

实验经历分享

别看VQVAE的代码不难,我做这些实验时还是经历了不少波折的。

一开始,我花一天就把代码写完了,并完成了MNIST上的实验。我觉得在MNIST上做实验的难度太低,不过瘾,就准备把数据集换成CelebA再做一些实验。结果这一做就是两个星期。

换成CelebA后,我碰到的第一个问题是VQVAE训练速度太慢。我尝试减半模型参数,训练时间却减小得不明显。我大致猜出是数据读取占用了大量时间,用性能分析工具一查,果然如此。原来我在DataLoader中一直只用了一个线程,加上num_workers=4就好了。我还把数据集打包成LMDB格式进一步加快数据读取速度。

之后,我又发现VQVAE在CelebA上的重建效果很差。我尝试增加模型参数,没起作用。我又怀疑是64x64的图片质量太低,模型学不到东西,就尝试把输入尺寸改成128x128,并把数据集从CelebA换成CelebAHQ,重建效果依然不行。我调了很多参数,发现了一些奇怪的现象:在嵌入层前使用和不使用BatchNorm对结果的影响很大,且显式初始化嵌入层会让模型的误差一直居高不下。我实在是找不到问题,就拿代码对着别人的PyTorch实现一行一行比较过去。总算,我发现我在使用嵌入层时是用vq_embedding.weight.data[x](因为前面已经获取了这个矩阵,这样写比较自然),别人是用vq_embedding(x)。我的写法会把嵌入层排除在梯度计算外,嵌入层根本得不到优化。我说怎么换了一个嵌入层的初始化方法模型就根本训不动了。改完bug之后,只训了5个epoch,新模型的误差比原来训练数小时的模型要低了。新模型的重建效果非常好。

总算,任务完成了一半,现在只剩PixelCNN要训练了。我先尝试训练输入为128x128,嵌入数64的模型,采样结果很差。为了加快实验速度,我把输入尺寸减小到64x64,再次训练,采样结果还是不行。根据我之前的经验,PixelCNN的训练难度主要取决于类别数。于是,我把嵌入的数量从64改成了32,并大幅增加PixelCNN的参数量,再次训练。过了很久,训练误差终于降到0.08左右。我一测,这次的采样结果还不错。

这样看来,之前的采样效果不好,是输入128x128,嵌入数64的实验太难了。我毕竟只是想做一个demo,在一个小型实验上成功就行了,没必要花时间去做更耗时的实验。按理说,我应该就此收手。但是,我就是咽不下这一口气,就是想在128x128的实验上成功。我再次加大了PixelCNN的参数量,用128x128的配置,大火慢炖,训练了一天一夜。第二天一早起来,我看到这回的误差也降到了0.08。上次的实验误差降到这个程度时实验已经成功了。我迫不及待地去测试采样效果,却发现采样效果还是稀烂。没办法,我选择投降,开始写这篇文章,准备收工。

写到PixelCNN介绍的那一章节时,我正准备讲解代码。看到PixelCNN训练之前预处理除以color_level那一行时,我楞了一下:这行代码是用来做什么的来着?这段代码全是从PixelCNN项目里复制过来的。当时是做普通图片的图像生成,所以要对输入颜色做一个预处理,把整数颜色变成0~1之间的浮点数。但现在是在生成压缩图片,不能这样处理啊!我恍然大悟,知道是在处理离散输入时做错了。应该多加一个嵌入层,把离散值转换成向量。由于VQVAE的重点不在生成模型上,原论文根本没有强调PixelCNN在离散编码上的实现细节。网上几乎所有文章也都没谈这一点。因此,我在实现PixelCNN时,直接不假思索地把原来的代码搬了过来,根本没想过这种地方会出现bug。

把这处bug改完后,我再次开启训练。这下所有模型的采样结果都正常了。误差降到0.5左右就已经有不错的采样结果了,原来我之前把误差降到0.08完全是无用功。太气人了。

这次的实验让我学到了很多东西。首先是PyTorch编程上的一些注意事项:

  • 调用embedding.weight.data[x]是传不了梯度的。
  • 如果读数据时有费时的处理操作(读写硬盘、解码),要在Dataloader里设置num_workers

另外,在测试一个模型是否实现成功时有一个重要的准则:

  • 不要仅在简单的数据集(如MNIST)上测试。测试成功可能只是暴力拟合的结果。只有在一个难度较大的数据集上测试成功才能说模型没有问题。

在观察模型是否训成功时,还需要注意:

  • 训练误差降低不代表模型更优。训练误差的评价方法和模型实际使用方法可能完全不同。不能像我这样偷懒不加测试指标。

除了学到的东西外,我还有一些感想。在别人的项目的基础上修改、照着他人代码复现、完全自己动手从零开始写,对于深度学习项目来说,这三种实现方式的难度是依次递增的。改别人的项目,你可能去配置文件里改一两个数字就行了。而照着他人代码复现,最起码你能把代码改成和他人的代码一模一样,然后再去比较哪一块错了。自己动手写,则是有bug都找不到可以参考的地方了。说深度学习的算法难以调试,难就难在这里。效果不好,你很难说清是训练代码错了、超参数没设置好、训练流程错了,或是测试代码错了。可以出错的地方太多了,通常的代码调试手段难以用在深度学习项目上。

对于想要在深度学习上有所建树的初学者,我建议一定要从零动手复现项目。很多工程经验是难以总结的,只有踩了一遍坑才能知道。除了凭借经验外,还可以掌握一些特定的工程方法来减少bug的出现。比如运行训练之前先拿性能工具分析一遍,看看代码是否有误,是否可以提速;又比如可以训练几步后看所有可学习参数是否被正确修改。

2022年中旬,以扩散模型为核心的图像生成模型将AI绘画带入了大众的视野。实际上,在更早的一年之前,就有了一个能根据文字生成高清图片的模型——VQGAN。VQGAN不仅本身具有强大的图像生成能力,更是传承了前作VQVAE把图像压缩成离散编码的思想,推广了「先压缩,再生成」的两阶段图像生成思路,启发了无数后续工作。

VQGAN生成出的高清图片

在这篇文章中,我将对VQGAN的论文和源码中的关键部分做出解读,提炼出VQGAN中的关键知识点。由于VQGAN的核心思想和VQVAE如出一辙,我不会过多地介绍VQGAN的核心思想,强烈建议读者先去学懂VQVAE,再来看VQGAN。

VQGAN 核心思想

VQGAN的论文名为Taming Transformers for High-Resolution Image Synthesis,直译过来是「驯服Transformer模型以实现高清图像合成」。可以看出,该方法是在用Transformer生成图像。可是,为什么这个模型叫做VQGAN,是一个GAN呢?这是因为,VQGAN使用了两阶段的图像生成方法:

  • 训练时,先训练一个图像压缩模型(包括编码器和解码器两个子模型),再训练一个生成压缩图像的模型。
  • 生成时,先用第二个模型生成出一个压缩图像,再用第一个模型复原成真实图像。

其中,第一个图像压缩模型叫做VQGAN,第二个压缩图像生成模型是一个基于Transformer的模型。

为什么会有这种乍看起来非常麻烦的图像生成方法呢?要理解VQGAN的这种设计动机,有两条路线可以走。两条路线看待问题的角度不同,但实际上是在讲同一件事。

第一条路线是从Transformer入手。Transformer已经在文本生成领域大展身手。同时,Transformer也在视觉任务中开始崭露头角。相比擅长捕捉局部特征的CNN,Transformer的优势在于它能更好地融合图像的全局信息。可是,Transformer的自注意力操作开销太大,只能生成一些分辨率较低的图像。因此,作者认为,可以综合CNN和Transformer的优势,先用基于CNN的VQGAN把图像压缩成一个尺寸更小、信息更丰富的小图像,再用Transformer来生成小图像。

第二条路线是从VQVAE入手。VQVAE是VQGAN的前作,它有着和VQGAN一模一样两阶段图像生成方法。不同的是,VQVAE没有使用GAN结构,且其配套的压缩图像生成模型是基于CNN的。为提升VQVAE的生成效果,作者提出了两项改进策略:1) 图像压缩模型VQVAE仅使用了均方误差,压缩图像的复原结果较为模糊,可以把图像压缩模型换成GAN;2) 在生成压缩图片这个任务上,基于CNN的图像生成模型比不过Transformer,可以用Transformer代替原来的CNN。

第一条思路是作者在论文的引言中描述的,听起来比较高大上;而第二条思路是读者读过文章后能够自然总结出来的,相对来说比较清晰易懂。如果你已经理解了VQVAE,你能通过第二条思路瞬间弄懂VQGAN的原理。说难听点,VQGAN就是一个改进版的VQVAE。然而,VQGAN的改进非常有效,且使用了若干技巧来实现带约束(比如根据文字描述)的高清图像生成,有非常多地方值得学习。

在下文中,我将先补充VQVAE的背景以方便讨论,再介绍VQGAN论文的四大知识点:VQGAN的设计细节、生成压缩图像的Transformer的设计细节、带约束图像生成的实现方法、高清图像生成的实现方法。

VQVAE 背景知识补充

VQVAE的学习目标是用一个编码器把图像压缩成离散编码,再用一个解码器把图像尽可能地还原回原图像。

通俗来说,VQVAE就是把一幅真实图像压缩成一个小图像。这个小图像和真实图像有着一些相同的性质:小图像的取值和像素值(0-255的整数)一样,都是离散的;小图像依然是二维的,保留了某些空间信息。因此,VQVAE的示意图画成这样会更形象一些:

但小图像和真实图像有一个关键性的区别:与像素值不同,小图像的离散取值之间没有关联。真实图像的像素值其实是一个连续颜色的离散采样,相邻的颜色值也更加相似。比如颜色254和颜色253和颜色255比较相似。而小图像的取值之间是没有关联的,你不能说编码为1与编码为0和编码为2比较相似。由于神经网络不能很好地处理这种离散量,在实际实现中,编码并不是以整数表示的,而是以类似于NLP中的嵌入向量的形式表示的。VAE使用了嵌入空间(又称codebook)来完成整数序号到向量的转换。

为了让任意一个编码器输出向量都变成一个固定的嵌入向量,VQVAE采取了一种离散化策略:把每个输出向量$z_e(x)$替换成嵌入空间中最近的那个向量$z_q(x)$。$z_e(x)$的离散编码就是$z_q(x)$在嵌入空间的下标。这个过程和把254.9的输出颜色值离散化成255的整数颜色值的原理类似。

VQVAE的损失函数由两部分组成:重建误差和嵌入空间误差。

其中,重建误差就是输入和输出之间的均方误差。

嵌入空间误差为解码器输出向量$z_e(x)$和它在嵌入空间对应向量$z_q(x)$的均方误差。

作者在误差中还使用了一种「停止梯度」的技巧。这个技巧在VQGAN中被完全保留,此处就不过多介绍了。

图像压缩模型 VQGAN

回顾了VQVAE的背景知识后,我们来正式认识VQGAN的几个创新点。第一点,图像压缩模型VQVAE被改进成了VQGAN。

一般VAE重建出来出来的图像都会比较模糊。这是因为VAE只使用了均方误差,而均方误差只能保证像素值尽可能接近,却不能保证图像的感知效果更加接近。为此,作者把GAN的一些方法引入VQVAE,改造出了VQGAN。

具体来说,VQGAN有两项改进。第一,作者用感知误差(perceptual loss)代替原来的均方误差作为VQGAN的重建误差。第二,作者引入了GAN的对抗训练机制,加入了一个基于图块的判别器,把GAN误差加入了总误差。

计算感知误差的方法如下:把两幅图像分别输入VGG,取出中间某几层卷积层的特征,计算特征图像之间的均方误差。如果你之前没学过相关知识,请搜索”perceptual loss”。

基于图块的判别器,即判别器不为整幅图输出一个真或假的判断结果,而是把图像拆成若干图块,分别输出每个图块的判断结果,再对所有图块的判断结果取一个均值。这只是GAN的一种改进策略而已,没有对GAN本身做太大的改动。如果你之前没学过相关知识,请搜索”PatchGAN”。

这样,总的误差可以写成:

其中,$\lambda$是控制两种误差比例的权重。作者在论文中使用了一个公式来自适应地设置$\lambda$。和普通的GAN一样,VQGAN的编码器、解码器(即生成器)、codebook会最小化误差,判别器会最大化误差。

用VQGAN代替VQVAE后,重建图片中的模糊纹理清晰了很多。

有了一个保真度高的图像压缩模型,我们可以进入下一步,训练一个生成压缩图像的模型。

基于 Transformer 的压缩图像生成模型

如前所述,经VQGAN得到的压缩图像与真实图像有一个本质性的不同:真实图像的像素值具有连续性,相邻的颜色更加相似,而压缩图像的像素值则没有这种连续性。压缩图像的这一特性让寻找一个压缩图像生成模型变得异常困难。多数强大的真实图像生成模型(比如GAN)都是输出一个连续的浮点颜色值,再做一个浮点转整数的操作,得到最终的像素值。而对于压缩图像来说,这种输出连续颜色的模型都不适用了。因此,之前的VQVAE使用了一个能建模离散颜色的PixelCNN模型作为压缩图像生成模型。但PixelCNN的表现不够优秀。

恰好,功能强大的Transformer天生就支持建模离散的输出。在NLP中,每个单词都可以用一个离散的数字表示。Transformer会不断生成表示单词的数字,以达到生成句子的效果。

Transformer 随机生成句子的过程

为了让Transformer生成图像,我们可以把生成句子的一个个单词,变成生成压缩图像的一个个像素。但是,要让Transformer生成二维图像,还需要克服一个问题:在生成句子时,Transformer会先生成第一个单词,再根据第一个单词生成第二个单词,再根据第一、第二个单词生成第三个单词……。也就是说,Transformer每次会根据之前所有的单词来生成下一单词。而图像是二维数据,没有先后的概念,怎样让像素和文字一样有先后顺序呢?

VQGAN的作者使用了自回归图像生成模型的常用做法,给图像的每个像素从左到右,从上到下规定一个顺序。有了先后顺序后,图像就可以被视为一个一维句子,可以用Transfomer生成句子的方式来生成图像了。在第$i$步,Transformer会根据前$i - 1$个像素$s_{ < i}$生成第$i$个像素$s_i$,

带约束的图像生成

在生成新图像时,我们更希望模型能够根据我们的需求生成图像。比如,我们希望模型生成「一幅优美的风景画」,又或者希望模型在一幅草图的基础上作画。这些需求就是模型的约束。为了实现带约束的图像生成,一般的做法是先有一个无约束(输入是随机数)的图像生成模型,再在这个模型的基础上把一个表示约束的向量插入进图像生成的某一步。

把约束向量插入进模型的方法是需要设计的,插入约束向量的方法往往和模型架构有着密切关系。比如假设一个生成模型是U-Net架构,我们可以把约束向量和当前特征图拼接在一起,输入进U-Net的每一大层。

为了实现带约束的图像生成,VQGAN的作者再次借鉴了Transformer实现带约束文字生成的方法。许多自然语言处理任务都可以看成是带约束的文字生成。比如机器翻译,其实可以看成在给定一种语言的句子的前提下,让模型「随机」生成一个另一种语言的句子。比如要把「简要访问非洲」翻译成英语,我们可以对之前无约束文字生成的Transformer做一些修改。

也就是说,给定约束的句子$c$,在第$i$步,Transformer会根据前$i-1$个输出单词$s_{ < i}$以及$c$生成第$i$个单词$s_i$。表示约束的单词被添加到了所有输出之前,作为这次「随机生成」的额外输入。

上述方法并不是唯一的文字生成方法。这种文字生成方法被称为”decoder-only”。实际上,也有使用一个编码器来额外维护约束信息的文字生成方法。最早的Transformer就用到了带编码器的方法。

我们同样可以把这种思想搬到压缩图像生成里。比如对于MNIST数据集,我们希望模型只生成0~9这些数字中某一个数字的手写图像。也就是说,约束是类别信息,约束的取值是0~9。我们就可以把这个0~9的约束信息添加到Transformer的输入$s_{ < i}$之前,以实现由类别约束的图像生成。

但这种设计又会产生一个新的问题。假设约束条件不能简单地表示成整数,而是一些其他类型的数据,比如语义分割图像,那该怎么办呢?对于这种以图像形式表示的约束,作者的做法是,再训练另一个VQGAN,把约束图像压缩成另一套压缩图片。这一套压缩图片和生成图像的压缩图片有着不同的codebook,就像两种语言有着不同的单词一样。这样,约束图像也变成了一系列的整数,可以用之前的方法进行带约束图像生成了。

生成高清图像

由于Transformer注意力计算的开销很大,作者在所有配置中都只使用了$16 \times 16$的压缩图像,再增大压缩图像尺寸的话计算资源就不够了。而另一方面,每张图像在VQGAN中的压缩比例是有限的。如果图像压缩得过多,则VQGAN的重建质量就不够好了。因此,设边长压缩了$f$倍,则该方法一次能生成的图片的最大尺寸是$16f \times 16f$。在多项实验中,$f=16$的表现都较好。这样算下来,该方法一次只能生成$256 \times 256$的图片。这种尺寸的图片还称不上高清图片。

为了生成更大尺寸的图片,作者先训练好了一套能生成$256 \times 256$的图片的VQGAN+Transformer,再用了一种基于滑动窗口的采样机制来生成大图片。具体来说,作者把待生成图片划分成若干个$16\times16$像素的图块,每个图块对应压缩图像的一个像素。之后,在每一轮生成时,只有待生成图块周围的$16\times16$个图块($256\times256$个像素)会被输入进VQGAN和Transformer,由Transformer生成一个新的压缩图像像素,再把该压缩图像像素解码成图块。(在下面的示意图中,每个方块是一个图块,transformer的输入是$3\times3$个图块)

这个滑动窗口算法不是那么好理解,需要多想一下才能理解它的具体做法。在理解这个算法时,你可能会有这样的问题:上面的示意图中,待生成像素有的时候在最左边,有的时候在中间,有的时候在右边,每次约束它的像素都不一样。这么复杂的约束逻辑怎么编写?其实,Transformer自动保证了每个像素只会由之前的像素约束,而看不到后面的像素。因此,在实现时,只需要把待生成像素框起来,直接用Transformer预测待生成像素即可,不需要编写额外的约束逻辑。

如果你没有学过Transformer的话,理解这部分会有点困难。Transformer可以根据第1~k-1个像素并行地生成第2~k个像素,且保证生成每个像素时不会偷看到后面像素的信息。因此,假设我们要生成第i个像素,其实是预测了所有第2~k个像素的结果,再取出第i个结果,填回待生成图像。

由于论文篇幅有限,作者没有对滑动窗口机制做过多的介绍,也没有讲带约束的滑动窗口是怎么实现的。如果你在理解这一部分时碰到了问题,不用担心,这很正常。稍后我们会在代码阅读章节彻底理解滑动窗口的实现方法。我也是看了代码才看懂此处的做法。

作者在论文中解释了为什么用滑动窗口生成高清图像是合理的。作者先是讨论了两种情况,只要满足这两种情况中的任意一种,拿滑动窗口生成图像就是合理的。第一种情况是数据集的统计规律是几乎空间不变,也就是说训练集图片每$256\times256$个像素的统计规律是类似的。这和我们拿$3\times3$卷积卷图像是因为图像每$3\times3$个像素的统计规律类似的原理是一样的。第二种情况是有空间上的约束信息。比如之前提到的用语义分割图来指导图像生成。由于语义分割也是一张图片,它给每个待生成像素都提供了额外信息。这样,哪怕是用滑动窗口,在局部语义的指导下,模型也足以生成图像了。

若是两种情况都不满足呢?比如在对齐的人脸数据集上做无约束生成。在对齐的人脸数据集里,每张图片中人的五官所在的坐标是差不多的,图片的空间不变性不满足;做无约束生成,自然也没有额外的空间信息。在这种情况下,我们可以人为地添加一个坐标约束,即从左到右、从上到下地给每个像素标一个序号,把每个滑动窗口里的坐标序号做为约束。有了坐标约束后,就还原成了上面的第二种情况,每个像素有了额外的空间信息,基于滑动窗口的方法依然可行。

学完了论文的四大知识点,我们知道VQGAN是怎么根据约束生成高清图像的了。接下来,我们来看看论文的实验部分,看看作者是怎么证明方法的有效性的。

实验

在实验部分,作者先是分别验证了基于Transformer的压缩图像生成模型较以往模型的优越性(4.1节)、VQGAN较以往模型的优越性(4.4节末尾)、使用VQGAN做图像压缩的必要性及相关消融实验(4.3节),再把整个生成方法综合起来,在多项图像生成任务上与以往的图像生成模型做定量对比(4.4节),最后展示了该方法惊艳的带约束生成效果(4.2节)。

在论文4.1节中,作者验证了基于Transformer的压缩图像生成模型的有效性。之前,压缩图像都是使用能输出离散分布的PixelCNN系列模型来生成的。PixelCNN系列的最强模型是PixelSNAIL。为确保公平,作者对比了相同训练时间、相同训练步数下两个网络在不同训练集下的负对数似然(NLL)指标。结果表明,基于Transformer的模型确实训练得更快。

对于直接能建模离散分布的模型来说,NLL就是交叉熵损失函数。

在论文4.4节末尾,作者将VQGAN和之前的图像压缩模型对比,验证了引入感知误差和GAN结构的有效性。作者汇报了各模型重建图像集与原数据集(ImageNet的训练集和验证集)的FID(指标FID是越低越好)。同时,结果也说明,增大codebook的尺寸或者编码种类都能提升重建效果。

在论文4.3节中,作者验证了使用VQGAN的必要性。作者训了两个模型,一个直接让Transformer做真实图像生成,一个用VQGAN把图像边长压缩2倍,再用Transformer生成压缩图像。经比较,使用了VQGAN后,图像生成速度快了10多倍,且图像生成效果也有所提升。

另外,作者还做了有关图像边长压缩比例$f$的消融实验。作者固定让Transformer生成$16 \times 16$的压缩图片,即每次训练时用到的图像尺寸都是$16f \times 16f$。之后,作者训练训练了不同$f$下的模型,用各个模型来生成图片。结果显示$f=16$时效果最好。这是因为,在固定Transformer的生成分辨率的前提下,$f$越小,Transformer的感受野越小。如果Transformer的感受野过小,就学习不到足够的信息。

在论文4.4节中,作者探究了VQGAN+Transformer在多项基准测试(benchmark)上的结果。

首先是语义图像合成(根据语义分割图像来生成)任务。本文的这套方法还不错。

接着是人脸生成任务。这套方法表现还行,但还是比不过专精于某一任务的GAN。

作者还比较了各模型在ImageNet上的生成结果。这一比较的数据量较多,欢迎大家自行阅读原论文。

在论文4.2节中,作者展示了多才多艺的VQGAN+Transformer在各种约束下的图像生成结果。这些图像都是按照默认配置生成的,大小为$256\times256$。

作者还展示了使用了滑动窗口算法后,模型生成的不同分辨率的图像。

本文开头的那张高清图片也来自论文。

总结

VQGAN是一个改进版的VQVAE,它将感知误差和GAN引入了图像压缩模型,把压缩图像生成模型替换成了更强大的Transformer。相比纯种的GAN(如StyleGAN),VQGAN的强大之处在于它支持带约束的高清图像生成。VQGAN借助NLP中”decoder-only”策略实现了带约束图像生成,并使用滑动窗口机制实现了高清图像生成。虽然在某些特定任务上VQGAN还是落后于其他GAN,但VQGAN的泛化性和灵活性都要比纯种GAN要强。它的这些潜力直接促成了Stable Diffusion的诞生。

如果你是读完了VQVAE再来读的VQGAN,为了完全理解VQGAN,你只需要掌握本文提到的4个知识点:VQVAE到VQGAN的改进方法、使用Transformer做图像生成的方法、使用”decoder-only”策略做带约束图像生成的方法、用滑动滑动窗口生成任意尺寸的图片的思想。

代码阅读

在代码阅读章节中,我将先简略介绍官方源码的项目结构以方便大家学习,再介绍代码中的几处核心代码。具体来说,我会介绍模型是如何组织配置文件的、模型的定义代码在哪、训练代码在哪、采样代码在哪,同时我会主要分析VQGAN的结构、Transformer的结构、损失函数、滑动窗口采样算法这几部分的代码。

官方源码地址:https://github.com/CompVis/taming-transformers。

官方的Git仓库里有很多很大的图片,且git记录里还藏了一些很大的数据,整个Git仓库非常大。如果你的网络不好,建议以zip形式下载仓库,或者只把代码部分下载下来。

项目结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
├─assets
├─configs
├─scripts
└─taming
├─data
│ └─conditional_builder
├─models
└─modules
├─diffusionmodules
├─discriminator
├─losses
├─misc
├─transformer
└─vqvae

configs目录下存放的是模型配置文件。VQGAN和Transformer的模型配置是分开来放的。每个模型配置文件都会指向一个Python模型类,比如taming.models.vqgan.VQModel,配置里的参数就是模型类的初始化参数。我们可用通过阅读配置文件找到模型的定义位置。

运行脚本包括根目录下的main.pyscripts文件夹下的脚本。main.py是用于训练的。scripts文件夹下有各种采样脚本和数据集可视化脚本。

taming是源代码的主目录。其data子文件夹下放置了各数据集的预处理代码,models放置了VQGAN和Transformer PyTorch模型的定义代码,modules则放置了模型中用到的模块,主要包括VQGAN编码解码模块(diffusionmodules)、判别器模块(discriminator)、误差模块(losses)、Transformer模块(transformer)、codebook模块(vqvae)。

VQGAN 模型结构

打开configs\faceshq_vqgan.yaml,我们能够找到高清人脸生成任务使用的VQGAN模型配置。我们来学习一下这个模型的定义方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
model:
base_learning_rate: 4.5e-6
target: taming.models.vqgan.VQModel
params:
embed_dim: 256
n_embed: 1024
ddconfig:
...

lossconfig:
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
params:
...

从配置文件的target字段中,我们知道VQGAN定义在模块taming.models.vqgan.VQModel中。我们可以打开taming\models\vqgan.py这个文件,查看其中VQModel类的代码。

首先先看一下初始化函数。初始化函数主要是初始化了encoderdecoderlossquantize这几个模块,我们可以从文件开头的import语句中找到这几个模块的定义位置。不过,先不急,我们来继续看一下模型的前向传播函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import GumbelQuantize
from taming.modules.vqvae.quantize import EMAVectorQuantizer

class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap, sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.image_key = image_key
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor

模型的前向传播逻辑非常清晰。self.encoder可以把一张图片变为特征,self.decoder可以把特征变回图片。self.quant_convpost_quant_conv则分别完成了编码器到codebook、codebook到解码器的通道数转换。self.quantize实现了VQVAE和VQGAN中那个找codebook里的最近邻、替换成最近邻的操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info

def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec

def forward(self, input):
quant, diff, _ = self.encode(input)
dec = self.decode(quant)
return dec, diff

接下来,我们再看一看VQGAN的各个模块的定义。编码器和解码器的定义都可以在taming\modules\diffusionmodules\model.py里找到。VQGAN使用的编码器和解码器基于DDPM论文中的U-Net架构(而此架构又可以追溯到PixelCNN++的模型架构)。相比于最经典的U-Net,此U-Net每一层由若干个残差块和若干个自注意力块构成。为了把这个U-Net用到VQGAN里,U-Net的下采样部分和上采样部分被拆开,分别做成了VQGAN的编码器和解码器。

此处代码过长,我就只贴出部分关键代码了。以下是编码器的__init__函数和forward函数的关键代码。self.down存储了U-Net各层的模块。对于第i层,down[i].block是所有残差块,down[i].attn是所有自注意力块,down[i].downsample是下采样操作。它们在forward里会被依次调用。解码器的结构与之类似,只不过下采样变成了上采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, **ignore_kwargs):
super().__init__()
...
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)

...


def forward(self, x):
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
...

return h

之后,我们再看看离散化层的代码,即把编码器的输出变成codebook里的嵌入的实现代码。作者在taming\modules\vqvae\quantize.py中提供了VQVAE原版的离散化操作以及若干个改进过的离散化操作。我们就来看一下原版的离散化模块VectorQuantizer是怎么实现的。

离散化模块的初始化非常简洁,主要是初始化了一个嵌入层。

1
2
3
4
5
6
7
8
9
10
class VectorQuantizer(nn.Module):
def __init__(self, n_e, e_dim, beta):
super(VectorQuantizer, self).__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta

self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

在前向传播时,作者先是算出了编码器输出z和所有嵌入的距离d,再用argmin算出了最近邻嵌入的下标min_encodings,最后根据下标取出解码器输入z_q。同时,该函数还计算了其他几个可能用到的量,比如和codebook有关的误差 loss。注意,在计算lossz_q时,作者都使用到了停止梯度算子(.detach())。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def forward(self, z):
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
torch.matmul(z_flattened, self.embedding.weight.t())

## could possible replace this here
# #\start...
# find closest encodings
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)

min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.n_e).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)

z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
#.........\end


# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
torch.mean((z_q - z.detach()) ** 2)

# preserve gradients
z_q = z + (z_q - z).detach()

# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()

return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

VQGAN的三个主要模块已经看完了。最后,我们来看一下误差的定义。误差的定义在taming\modules\losses\vqperceptual.pyVQLPIPSWithDiscriminator类里。误差类名里的LPIPS(Learned Perceptual Image Patch Similarity,学习感知图像块相似度)就是感知误差的全称,”WithDiscriminator”表示误差是带了判定器误差的。我们来把这两类误差分别看一下。

说实话,这个误差模块乱得一塌糊涂,一边自己在算误差,一边又维护了codebook误差和重建误差的权重,最后会把自己维护的两个误差和其他误差合在一起输出。功能全部耦合在一起。我们就跳过这个类的实现细节,主要关注self.perceptual_lossself.discriminator是怎么调用其他模块的。

1
2
3
4
5
6
7
8
9
10
from taming.modules.losses.lpips import LPIPS
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, ...):
super().__init__()

self.perceptual_loss = LPIPS().eval()

self.discriminator = NLayerDiscriminator...

感知误差模块在taming\modules\losses\vqperceptual.py文件里。这个文件来自GitHub项目 PerceptualSimilarity。

感知误差可以简单地理解为两张图片在VGG中几个卷积层输出的误差的加权和。加权的权重是可以学习的。作者使用的是已经学习好的感知误差。感知误差的初始化函数如下。其中,self.lin0等模块就是算权重的模块,self.net是VGG。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class LPIPS(nn.Module):
# Learned perceptual metric
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512] # vg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False

在算误差时,先是把图像inputtarget都输入进VGG,获取各层输出outs0, outs1,再求出两个图像的输出的均方误差diffs,最后用lins给各层误差加权,求和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val

GAN的判别器写在taming\modules\discriminator\model.py文件里。这个文件来自GitHub上的 pytorch-CycleGAN-and-pix2pix 项目。这个判别器非常简单,就是一个全卷积网络。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator as in Pix2Pix
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminator, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d

kw = 4
padw = 1
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]

nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]

sequence += [
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)

def forward(self, input):
"""Standard forward."""
return self.main(input)

Transformer 模型结构

此方法使用的Transformer是GPT2。我们先看一下该项目封装Transformer的模型类taming.models.cond_transformer.Net2NetTransformer,再稍微看一下GPT类taming.modules.transformer.mingpt.GPT的具体实现。

Net2NetTransformer主要是实现了论文中提到的带约束生成。它会把输入x和约束c分别用一个VQGAN转成压缩图像,把图像压扁成一维,再调用GPT。我们来看一下这个类的主要内容。

初始化函数主要是初始化了输入图像的VQGAN self.first_stage_model、约束图像的VQGAN self.cond_stage_model、Transformer self.transformer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Net2NetTransformer(pl.LightningModule):
def __init__(self,
transformer_config,
first_stage_config,
cond_stage_config,
permuter_config=None,
ckpt_path=None,
ignore_keys=[],
first_stage_key="image",
cond_stage_key="depth",
downsample_cond_size=-1,
pkeep=1.0,
sos_token=0,
unconditional=False,
):
super().__init__()
self.be_unconditional = unconditional
self.sos_token = sos_token
self.first_stage_key = first_stage_key
self.cond_stage_key = cond_stage_key
self.init_first_stage_from_ckpt(first_stage_config)
self.init_cond_stage_from_ckpt(cond_stage_config)
if permuter_config is None:
permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
self.permuter = instantiate_from_config(config=permuter_config)
self.transformer = instantiate_from_config(config=transformer_config)

if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.downsample_cond_size = downsample_cond_size
self.pkeep = pkeep

def init_first_stage_from_ckpt(self, config):
model = instantiate_from_config(config)
model = model.eval()
model.train = disabled_train
self.first_stage_model = model

def init_cond_stage_from_ckpt(self, config):
...
self.cond_stage_model = ...

模型的前向传播函数如下。一开始,函数调用encode_to_zencode_to_c,根据self.cond_stage_modelself.first_stage_model把约束图像和输入图像编码成压扁至一维的压缩图像。之后函数做了一个类似Dropout的操作,根据self.pkeep随机替换掉约束编码。最后,函数把约束编码和输入编码拼接起来,使用通常方法调用Transformer。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def forward(self, x, c):
# one step to produce the logits
_, z_indices = self.encode_to_z(x)
_, c_indices = self.encode_to_c(c)

if self.training and self.pkeep < 1.0:
mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
device=z_indices.device))
mask = mask.round().to(dtype=torch.int64)
r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
a_indices = mask*z_indices+(1-mask)*r_indices
else:
a_indices = z_indices

cz_indices = torch.cat((c_indices, a_indices), dim=1)

# target includes all sequence elements (no need to handle first one
# differently because we are conditioning)
target = z_indices
# make the prediction
logits, _ = self.transformer(cz_indices[:, :-1])
# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
logits = logits[:, c_indices.shape[1]-1:]

return logits, target

GPT2的结构不是本文的重点,我们就快速把模型结构过一遍了。GPT2的模型定义在taming.modules.transformer.mingpt.GPT里。GPT2的结构并不复杂,就是一个只有解码器的Transformer。前向传播时,数据先通过嵌入层self.tok_emb,再经过若干个Transformer模块self.blocks,最后过一个LayerNorm层self.ln_f和线性层self.head

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class GPT(nn.Module):

def forward(self, idx, embeddings=None, targets=None):
# forward the GPT model
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector

if embeddings is not None: # prepend explicit embeddings
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

t = token_embeddings.shape[1]
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(token_embeddings + position_embeddings)
x = self.blocks(x)
x = self.ln_f(x)
logits = self.head(x)

# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

return logits, loss

每个Transformer块就是非常经典的自注意力加全连接层。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Block(nn.Module):
""" an unassuming Transformer block """
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(), # nice
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.resid_pdrop),
)

def forward(self, x, layer_past=None, return_present=False):
# TODO: check that training still works
if return_present: assert not self.training
# layer past: tuple of length two with B, nh, T, hs
attn, present = self.attn(self.ln1(x), layer_past=layer_past)

x = x + attn
x = x + self.mlp(self.ln2(x))
if layer_past is not None or return_present:
return x, present
return x

基于滑动窗口的带约束图像生成

看完了所有模型的结构,我们最后来学习一下论文中没能详细介绍的滑动窗口算法。在scripts\taming-transformers.ipynb里有一个采样算法的最简实现,我们就来学习一下这份代码。

这份代码可以根据一幅语义分割图像来生成高清图像。一开始,代码会读入模型和语义分割图像。大致的代码为:

1
2
3
4
5
6
7
from taming.models.cond_transformer import Net2NetTransformer
model = Net2NetTransformer(**config.model.params)
from PIL import Image
import numpy as np
segmentation_path = "data/sflckr_segmentations/norway/25735082181_999927fe5a_b.png"
segmentation = Image.open(segmentation_path)
...


之后,代码把约束图像用对应的VQGAN编码进压缩空间,得到c_indices。由于待生成图像为空,我们可以随便生成一个待生成图像的压缩图像z_indices,代码中使用了randint初始化待生成的压缩图像。

1
2
3
4
5
6
7
8
c_code, c_indices = model.encode_to_c(segmentation)
z_indices = torch.randint(codebook_size, z_indices_shape, device=model.device)

idx = z_indices
idx = idx.reshape(z_code_shape[0],z_code_shape[2],z_code_shape[3])

cidx = c_indices
cidx = cidx.reshape(c_code.shape[0],c_code.shape[2],c_code.shape[3])

最后就是最关键的滑动窗口采样部分了。我们先稍微浏览一遍代码,再详细地一行一行看过去。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
temperature = 1.0
top_k = 100

for i in range(0, z_code_shape[2]-0):
if i <= 8:
local_i = i
elif z_code_shape[2]-i < 8:
local_i = 16-(z_code_shape[2]-i)
else:
local_i = 8
for j in range(0,z_code_shape[3]-0):
if j <= 8:
local_j = j
elif z_code_shape[3]-j < 8:
local_j = 16-(z_code_shape[3]-j)
else:
local_j = 8

i_start = i-local_i
i_end = i_start+16
j_start = j-local_j
j_end = j_start+16

patch = idx[:,i_start:i_end,j_start:j_end]
patch = patch.reshape(patch.shape[0],-1)
cpatch = cidx[:, i_start:i_end, j_start:j_end]
cpatch = cpatch.reshape(cpatch.shape[0], -1)
patch = torch.cat((cpatch, patch), dim=1)
logits,_ = model.transformer(patch[:,:-1])
logits = logits[:, -256:, :]
logits = logits.reshape(z_code_shape[0],16,16,-1)
logits = logits[:,local_i,local_j,:]

logits = logits/temperature

if top_k is not None:
logits = model.top_k_logits(logits, top_k)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx[:,i,j] = torch.multinomial(probs, num_samples=1)

x_sample = model.decode_to_img(idx, z_code_shape)
show_image(x_sample)

一开始的temperaturetop_k是得到logit后的采样参数,和滑动窗口算法无关。
1
2
temperature = 1.0
top_k = 100

进入生成图像循环后,i, j分别表示压缩图像的竖索引和横索引,i_start, i_end, j_start, j_end是滑动窗口上下左右边界。

1
2
3
4
5
6
7
8
for i in range(0, z_code_shape[2]-0):
...
for j in range(0,z_code_shape[3]-0):
...
i_start = i-local_i
i_end = i_start+16
j_start = j-local_j
j_end = j_start+16

为了获取这四个滑动窗口的范围,代码用了若干条件语句计算待生成像素在滑动窗口里的相对位置local_i, local_j

1
2
3
4
5
6
7
8
9
10
11
12
13
14
for i in range(0, z_code_shape[2]-0):
if i <= 8:
local_i = i
elif z_code_shape[2]-i < 8:
local_i = 16-(z_code_shape[2]-i)
else:
local_i = 8
for j in range(0,z_code_shape[3]-0):
if j <= 8:
local_j = j
elif z_code_shape[3]-j < 8:
local_j = 16-(z_code_shape[3]-j)
else:
local_j = 8

得到了滑动窗口的边界后,代码用滑动窗口从约束图像的压缩图像和待生成图像的压缩图像上各取出一个图块,并拼接起来。

1
2
3
4
5
patch = idx[:,i_start:i_end,j_start:j_end]
patch = patch.reshape(patch.shape[0],-1)
cpatch = cidx[:, i_start:i_end, j_start:j_end]
cpatch = cpatch.reshape(cpatch.shape[0], -1)
patch = torch.cat((cpatch, patch), dim=1)

之后,只需要把拼接的图块直接输入进Transformer,得到输出logits,再用local_i,local_j去输出图块的对应位置取出下一个压缩图像像素的概率分布,就可以随机生成下一个压缩图像像素了。如前文所述,Transformer类会把二维的图块压扁到一维,输入进GPT。同时,GPT会自动保证前面的像素看不到后面的像素,我们不需要人为地指定约束像素。这个地方的调用逻辑其实非常简单。

1
2
3
4
logits,_ = model.transformer(patch[:,:-1])
logits = logits[:, -256:, :]
logits = logits.reshape(z_code_shape[0],16,16,-1)
logits = logits[:,local_i,local_j,:]

最后只要从logits里采样,把采样出的压缩图像像素填入idx,就完成了一步生成。

1
2
3
4
5
6
7
logits = logits/temperature

if top_k is not None:
logits = model.top_k_logits(logits, top_k)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx[:,i,j] = torch.multinomial(probs, num_samples=1)

反复执行循环,就能将压缩图像生成完毕。最后将压缩图像过一遍VQGAN的解码器即可得到最终的生成图像。

1
2
x_sample = model.decode_to_img(idx, z_code_shape)
show_image(x_sample)

参考资料

VQGAN论文:https://arxiv.org/abs/2012.09841

VQGAN GitHub:https://github.com/CompVis/taming-transformers

如果你需要补充学习早期工作,欢迎阅读我之前的文章。

Transformer解读

PixelCNN解读

VQVAE解读

在这篇文章中,我将详细地介绍一个英中翻译 Transformer 的 PyTorch 实现。这篇文章会完整地展示一个深度学习项目的搭建过程,从数据集准备,到模型定义、训练。这篇文章不仅会讲解如何把 Transformer 的论文翻译成代码,还会讲清楚代码实现中的诸多细节,并分享我做实验时碰到的种种坑点。相信初学者能够从这篇文章中学到丰富的知识。

项目网址: https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/Transformer

如果你对 Transformer 的论文不熟,欢迎阅读我之前的文章:Attention Is All You Need (Transformer) 论文精读

数据集准备

我在 https://github.com/P3n9W31/transformer-pytorch 项目中找到了一个较小的中英翻译数据集。数据集只有几KB大小,中英词表只有10000左右,比较适合做Demo。如果要实现更加强大实用的模型,则需要换更大的数据集。但相应地,你要多花费更多的时间来训练。

我在代码仓库中提供了data_load.py文件。执行这个文件后,实验所需要的数据会自动下载到项目目录的data文件夹下。

该数据集由cn.txt, en.txt, cn.txt.vocab.tsv, en.txt.vocab.tsv这四个文件组成。前两个文件包含相互对应的中英文句子,其中中文已做好分词,英文全为小写且标点已被分割好。后两个文件是预处理好的词表。语料来自2000年左右的中国新闻,其第一条的中文及其翻译如下:

1
2
目前 粮食 出现 阶段性 过剩 , 恰好 可以 以 粮食 换 森林 、 换 草地 , 再造 西部 秀美 山川 。
the present food surplus can specifically serve the purpose of helping western china restore its woodlands , grasslands , and the beauty of its landscapes .

词表则统计了各个单词的出现频率。通过使用词表,我们能实现单词和序号的相互转换(比如中文里的5号对应“的”字,英文里的5号对应”the”)。词表的前四个单词是特殊字符,分别为填充字符、频率太少没有被加入词典的词语、句子开始字符、句子结束字符。

1
2
3
4
5
6
7
8
<PAD>	1000000000
<UNK> 1000000000
<S> 1000000000
</S> 1000000000
的 8461
是 2047
和 1836
在 1784
1
2
3
4
5
6
7
8
<PAD>	1000000000
<UNK> 1000000000
<S> 1000000000
</S> 1000000000
the 13680
and 6845
of 6259
to 4292

只要运行一遍data_load.py下好数据后,我们待会就能用load_train_data()来获取已经打成batch的训练数据,并用API获取cn2idx, idx2cn, en2idx, idx2en这四个描述中英文序号与单词转换的词典。我们会在之后的训练代码里见到它们的用法。

Transformer 模型

准备好数据后,接下来就要进入这个项目最重要的部分——Transformer 模型实现了。我将按照代码的执行顺序,从前往后,自底向上地介绍 Transformer 的各个模块:Positional Encoding, MultiHeadAttention, Encoder & Decoder, 最后介绍如何把各个模块拼到一起。在这个过程中,我还会着重介绍一个论文里没有提及,但是代码实现时非常重要的一个细节——<pad>字符的处理。

说实话,用 PyTorch 实现 Transformer 没有什么有变数的地方,大家的代码写得都差不多,我也是参考着别人的教程写的。但是,Transformer 的代码实现中有很多坑。大部分人只会云淡风轻地介绍一下最终的代码成品,不会去讲他们 debug 耗费了多少时间,哪些地方容易出错。而我会着重讲一下代码中的一些细节,以及我碰到过的问题。

Positional Encoding

模型一开始是一个 Embedding 层加一个 Positional Encoding。Embedding 在 PyTorch 里已经有实现了,且不是文章的创新点,我们就直接来看 Positional Encoding 的写法。

求 Positional Encoding,其实就是求一个二元函数的许多函数值构成的矩阵。对于二元函数$PE(pos, i)$,我们要求出$pos \in [0, seqlen - 1], i \in [0, d_{model} - 1]$时所有的函数值,其中,$seqlen$是该序列的长度,$d_{model}$是每一个词向量的长度。

理论上来说,每个句子的序列长度$seqlen$是不固定的。但是,我们可以提前预处理一个$seqlen$很大的 Positional Encoding 矩阵 。每次有句子输入进来,根据这个句子的序列长度,去预处理好的矩阵里取一小段出来即可。

这样,整个类的实现应该如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class PositionalEncoding(nn.Module):

def __init__(self, d_model: int, max_seq_len: int):
super().__init__()

# Assume d_model is an even number for convenience
assert d_model % 2 == 0

i_seq = torch.linspace(0, max_seq_len - 1, max_seq_len)
j_seq = torch.linspace(0, d_model - 2, d_model // 2)
pos, two_i = torch.meshgrid(i_seq, j_seq)
pe_2i = torch.sin(pos / 10000**(two_i / d_model))
pe_2i_1 = torch.cos(pos / 10000**(two_i / d_model))
pe = torch.stack((pe_2i, pe_2i_1), 2).reshape(1, max_seq_len, d_model)

self.register_buffer('pe', pe, False)

def forward(self, x: torch.Tensor):
n, seq_len, d_model = x.shape
pe: torch.Tensor = self.pe
assert seq_len <= pe.shape[1]
assert d_model == pe.shape[2]
rescaled_x = x * d_model**0.5
return rescaled_x + pe[:, 0:seq_len, :]

代码中有不少需要讲解的部分。首先,先看一下预处理好的矩阵pe是怎么在__init__中算出来的。pe可以很直接地用两层循环算出来。由于这段预处理代码只会执行一次,相对于冗长的训练时间,哪怕生成pe的代码性能差一点也没事。然而,作为一个编程高手,我准备秀一下如何用并行的方法求出pe

为了并行地求pe,我们要初始化一个二维网格,表示自变量$pos, i$。生成网格可以用下面的代码实现。(由于$i$要分奇偶讨论,$i$的个数是$\frac{d_{model}}{2}$)

1
2
3
i_seq = torch.linspace(0, max_seq_len - 1, max_seq_len)
j_seq = torch.linspace(0, d_model - 2, d_model // 2)
pos, two_i = torch.meshgrid(i_seq, j_seq)

torch.meshgrid用于生成网格。比如torch.meshgrid([0, 1], [0, 1])就可以生成[[(0, 0), (0, 1)], [(1, 0), (1, 1)]]这四个坐标构成的网格。不过,这个函数会把坐标的两个分量分别返回。比如:

1
2
3
i, j = torch.meshgrid([0, 1], [0, 1])
# i: [[0, 0], [1, 1]]
# j: [[0, 1], [0, 1]]

利用这个函数的返回结果,我们可以把pos, two_i套入论文的公式,并行地分别算出奇偶位置的 PE 值。

1
2
pe_2i = torch.sin(pos / 10000**(two_i / d_model))
pe_2i_1 = torch.cos(pos / 10000**(two_i / d_model))

有了奇偶处的值,现在的问题是怎么把它们优雅地拼到同一个维度上。我这里先把它们堆成了形状为seq_len, d_model/2, 2的一个张量,再把最后一维展平,就得到了最后的pe矩阵。这一操作等于新建一个seq_len, d_model形状的张量,再把奇偶位置处的值分别填入。

1
pe = torch.stack((pe_2i, pe_2i_1), 2).reshape(1, max_seq_len, d_model)

最后,要注意一点。只用 self.pe = pe 记录这个量是不够好的。我们最好用 self.register_buffer('pe', pe, False) 把这个量登记成 torch.nn.Module 的一个存储区(这一步会自动完成self.pe = pe)。这里涉及到 PyTorch 的一些知识了。

PyTorch 的 Module 会记录两类参数,一类是 parameter 可学习参数,另一类是 buffer 不可学习的参数。把变量登记成 buffer 的最大好处是,在使用 model.to(device) 把一个模型搬到另一个设备上时,所有 parameterbuffer 都会自动被搬过去。另外,bufferparameter 一样,也可以被记录到 state_dict 中,并保存到文件里。register_buffer 的第三个参数决定了是否将变量加入 state_dict。由于 pe 可以直接计算,不需要记录,可以把这个参数设成 False

预处理好 pe 后,用起来就很方便了。每次读取输入的序列长度,从中取一段出来即可。

另外,Transformer 给嵌入层乘了个系数$\sqrt{d_{model}}$。为了方便起见,我把这个系数放到了 PositionalEncoding 类里面。

1
2
3
4
5
6
7
def forward(self, x: torch.Tensor):
n, seq_len, d_model = x.shape
pe: torch.Tensor = self.pe
assert seq_len <= pe.shape[1]
assert d_model == pe.shape[2]
rescaled_x = x * d_model**0.5
return rescaled_x + pe[:, 0:seq_len, :]

Scaled Dot-Product Attention

下一步是多头注意力层。为了实现多头注意力,我们先要实现 Transformer 里经典的注意力计算。而在讲注意力计算之前,我还要补充一下 Transformer 中有关 mask 的一些知识。

Transformer 里的 mask

Transformer 最大的特点就是能够并行训练。给定翻译好的第1~n个词语,它默认会并行地预测第2~(n+1)个下一个词语。为了模拟串行输出的情况,第$t$个词语不应该看到第$t+1$个词语之后的信息。

输入信息 输出
(y1, —, —, —) y2
(y1, y2, —, —) y3
(y1, y2, y3, —) y4
(y1, y2, y3, y4) y5

为了实现这一功能,Transformer 在Decoder里使用了掩码。掩码取1表示这个地方的数是有效的,取0表示这个地方的数是无效的。Decoder 里的这种掩码应该是一个上三角全1矩阵。

掩码是在注意力计算中生效的。对于掩码取0的区域,其softmax前的$QK^T$值取负无穷。这是因为,对于softmax

令$x_i=-\infty$可以让它在 softmax 的分母里不产生任何贡献。

以上是论文里提到的 mask,它用来模拟 Decoder 的串行推理。而在代码实现中,还有其他地方会产生 mask。在生成一个 batch 的数据时,要给句子填充 <pad>。这个特殊字符也没有实际意义,不应该对计算产生任何贡献。因此,有 <pad> 的地方的 mask 也应该为0。之后讲 Transformer 模型类时,我会介绍所有的 mask 该怎么生成,这里我们仅关注注意力计算是怎么用到 mask 的。

注意力计算

补充完了背景知识,我们来看注意力计算的实现代码。由于注意力计算没有任何的状态,因此它应该写成一个函数,而不是一个类。我们可以轻松地用 PyTorch 代码翻译注意力计算的公式。(注意,我这里的 mask 表示哪些地方要填负无穷,而不是像之前讲的表示哪些地方有效)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
MY_INF = 1e12

def attention(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: Optional[torch.Tensor] = None):
'''
Note: The dtype of mask must be bool
'''
# q shape: [n, heads, q_len, d_k]
# k shape: [n, heads, k_len, d_k]
# v shape: [n, heads, k_len, d_v]
assert q.shape[-1] == k.shape[-1]
d_k = k.shape[-1]
# tmp shape: [n, heads, q_len, k_len]
tmp = torch.matmul(q, k.transpose(-2, -1)) / d_k**0.5
if mask is not None:
tmp.masked_fill_(mask, -MY_INF)
tmp = F.softmax(tmp, -1)
# tmp shape: [n, heads, q_len, d_v]
tmp = torch.matmul(tmp, v)
return tmp

这里有一个很坑的地方。引入了 <pad> 带来的 mask 后,会产生一个新的问题:可能一整行数据都是失效的,softmax 用到的所有 $x_i$ 可能都是负无穷。

这个数是没有意义的。如果用torch.inf来表示无穷大,就会令exp(torch.inf)=0,最后 softmax 结果会出现 NaN,代码大概率是跑不通的。

但是,大多数 PyTorch Transformer 教程压根就没提这一点,而他们的代码又还是能够跑通。拿放大镜仔细对比了代码后,我发现,他们的无穷大用的不是 torch.inf,而是自己随手设的一个极大值。这样,exp(-MY_INF)得到的不再是0,而是一个极小值。softmax 的结果就会等于分母的项数,而不是 NaN,不会有数值计算上的错误。

Multi-Head Attention

有了注意力计算,就可以实现多头注意力层了。多头注意力层是有学习参数的,它应该写成一个类。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class MultiHeadAttention(nn.Module):

def __init__(self, heads: int, d_model: int, dropout: float = 0.1):
super().__init__()

assert d_model % heads == 0
# dk == dv
self.d_k = d_model // heads
self.heads = heads
self.d_model = d_model
self.q = nn.Linear(d_model, d_model)
self.k = nn.Linear(d_model, d_model)
self.v = nn.Linear(d_model, d_model)
self.out = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)

def forward(self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: Optional[torch.Tensor] = None):
# batch should be same
assert q.shape[0] == k.shape[0]
assert q.shape[0] == v.shape[0]
# the sequence length of k and v should be aligned
assert k.shape[1] == v.shape[1]

n, q_len = q.shape[0:2]
n, k_len = k.shape[0:2]
q_ = self.q(q).reshape(n, q_len, self.heads, self.d_k).transpose(1, 2)
k_ = self.k(k).reshape(n, k_len, self.heads, self.d_k).transpose(1, 2)
v_ = self.v(v).reshape(n, k_len, self.heads, self.d_k).transpose(1, 2)

attention_res = attention(q_, k_, v_, mask)
concat_res = attention_res.transpose(1, 2).reshape(
n, q_len, self.d_model)
concat_res = self.dropout(concat_res)

output = self.out(concat_res)
return output

这段代码一处很灵性的地方。在 Transformer 的论文中,多头注意力是先把每个词的表示拆成$h$个头,再对每份做投影、注意力,最后拼接起来,再投影一次。其实,拆开与拼接操作是多余的。我们可以通过一些形状上的操作,等价地实现拆开与拼接,以提高运行效率。

具体来说,我们可以一开始就让所有头的数据经过同一个线性层。之后在做注意力之前把头和序列数这两维转置一下。这两步操作和拆开来做投影、注意力是等价的。做完了注意力操作之后,再把两个维度转置回来,这和拼接操作是等价的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def forward(self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: Optional[torch.Tensor] = None):
# batch should be same
assert q.shape[0] == k.shape[0]
assert q.shape[0] == v.shape[0]
# the sequence length of k and v should be aligned
assert k.shape[1] == v.shape[1]

n, q_len = q.shape[0:2]
n, k_len = k.shape[0:2]
q_ = self.q(q).reshape(n, q_len, self.heads, self.d_k).transpose(1, 2)
k_ = self.k(k).reshape(n, k_len, self.heads, self.d_k).transpose(1, 2)
v_ = self.v(v).reshape(n, k_len, self.heads, self.d_k).transpose(1, 2)

attention_res = attention(q_, k_, v_, mask)
concat_res = attention_res.transpose(1, 2).reshape(
n, q_len, self.d_model)
concat_res = self.dropout(concat_res)

output = self.out(concat_res)
return output

前馈网络

前馈网络太简单了,两个线性层,没什么好说的。注意内部那个隐藏层的维度大小$d_{ff}$会比$d_{model}$更大一点。

1
2
3
4
5
6
7
8
9
10
11
12
13
class FeedForward(nn.Module):

def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.layer1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.layer2 = nn.Linear(d_ff, d_model)

def forward(self, x):
x = self.layer1(x)
x = self.dropout(F.relu(x))
x = self.layer2(x)
return x

Encoder & Decoder

准备好一切组件后,就可以把模型一层一层搭起来了。先搭好每个 Encoder 层和 Decoder 层,再拼成 Encoder 和 Decoder。

Encoder 层和 Decoder 层的结构与论文中的描述一致,且每个子层后面都有一个 dropout,和上一层之间使用了残差连接。归一化的方法是 LayerNorm。顺带一提,不仅是这些层,前面很多子层的计算中都加入了 dropout。

再提一句 mask。由于 encoder 和 decoder 的输入不同,它们的填充情况不同,产生的 mask 也不同。后文会展示这些 mask 的生成方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class EncoderLayer(nn.Module):

def __init__(self,
heads: int,
d_model: int,
d_ff: int,
dropout: float = 0.1):
super().__init__()
self.self_attention = MultiHeadAttention(heads, d_model, dropout)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)

def forward(self, x, src_mask: Optional[torch.Tensor] = None):
tmp = self.self_attention(x, x, x, src_mask)
tmp = self.dropout1(tmp)
x = self.norm1(x + tmp)
tmp = self.ffn(x)
tmp = self.dropout2(tmp)
x = self.norm2(x + tmp)
return x


class DecoderLayer(nn.Module):

def __init__(self,
heads: int,
d_model: int,
d_ff: int,
dropout: float = 0.1):
super().__init__()
self.self_attention = MultiHeadAttention(heads, d_model, dropout)
self.attention = MultiHeadAttention(heads, d_model, dropout)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)

def forward(self,
x,
encoder_kv: torch.Tensor,
dst_mask: Optional[torch.Tensor] = None,
src_dst_mask: Optional[torch.Tensor] = None):
tmp = self.self_attention(x, x, x, dst_mask)
tmp = self.dropout1(tmp)
x = self.norm1(x + tmp)
tmp = self.attention(x, encoder_kv, encoder_kv, src_dst_mask)
tmp = self.dropout2(tmp)
x = self.norm2(x + tmp)
tmp = self.ffn(x)
tmp = self.dropout3(tmp)
x = self.norm3(x + tmp)
return x

Encoder 和 Decoder 就是在所有子层前面加了一个嵌入层、一个位置编码,再把多个子层堆起来了而已,其他输入输出照搬即可。注意,我们可以给嵌入层输入pad_idx参数,让<pad>的计算不对梯度产生贡献。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class Encoder(nn.Module):

def __init__(self,
vocab_size: int,
pad_idx: int,
d_model: int,
d_ff: int,
n_layers: int,
heads: int,
dropout: float = 0.1,
max_seq_len: int = 120):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model, pad_idx)
self.pe = PositionalEncoding(d_model, max_seq_len)
self.layers = []
for i in range(n_layers):
self.layers.append(EncoderLayer(heads, d_model, d_ff, dropout))
self.layers = nn.ModuleList(self.layers)
self.dropout = nn.Dropout(dropout)

def forward(self, x, src_mask: Optional[torch.Tensor] = None):
x = self.embedding(x)
x = self.pe(x)
x = self.dropout(x)
for layer in self.layers:
x = layer(x, src_mask)
return x


class Decoder(nn.Module):

def __init__(self,
vocab_size: int,
pad_idx: int,
d_model: int,
d_ff: int,
n_layers: int,
heads: int,
dropout: float = 0.1,
max_seq_len: int = 120):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model, pad_idx)
self.pe = PositionalEncoding(d_model, max_seq_len)
self.layers = []
for i in range(n_layers):
self.layers.append(DecoderLayer(heads, d_model, d_ff, dropout))
self.layers = nn.Sequential(*self.layers)
self.dropout = nn.Dropout(dropout)

def forward(self,
x,
encoder_kv,
dst_mask: Optional[torch.Tensor] = None,
src_dst_mask: Optional[torch.Tensor] = None):
x = self.embedding(x)
x = self.pe(x)
x = self.dropout(x)
for layer in self.layers:
x = layer(x, encoder_kv, dst_mask, src_dst_mask)
return x

Transformer 类

终于,激动人心的时候到来了。我们要把各个子模块组成变形金刚(Transformer)了。先过一遍所有的代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Transformer(nn.Module):

def __init__(self,
src_vocab_size: int,
dst_vocab_size: int,
pad_idx: int,
d_model: int,
d_ff: int,
n_layers: int,
heads: int,
dropout: float = 0.1,
max_seq_len: int = 200):
super().__init__()
self.encoder = Encoder(src_vocab_size, pad_idx, d_model, d_ff,
n_layers, heads, dropout, max_seq_len)
self.decoder = Decoder(dst_vocab_size, pad_idx, d_model, d_ff,
n_layers, heads, dropout, max_seq_len)
self.pad_idx = pad_idx
self.output_layer = nn.Linear(d_model, dst_vocab_size)

def generate_mask(self,
q_pad: torch.Tensor,
k_pad: torch.Tensor,
with_left_mask: bool = False):
# q_pad shape: [n, q_len]
# k_pad shape: [n, k_len]
# q_pad k_pad dtype: bool
assert q_pad.device == k_pad.device
n, q_len = q_pad.shape
n, k_len = k_pad.shape

mask_shape = (n, 1, q_len, k_len)
if with_left_mask:
mask = 1 - torch.tril(torch.ones(mask_shape))
else:
mask = torch.zeros(mask_shape)
mask = mask.to(q_pad.device)
for i in range(n):
mask[i, :, q_pad[i], :] = 1
mask[i, :, :, k_pad[i]] = 1
mask = mask.to(torch.bool)
return mask

def forward(self, x, y):

src_pad_mask = x == self.pad_idx
dst_pad_mask = y == self.pad_idx
src_mask = self.generate_mask(src_pad_mask, src_pad_mask, False)
dst_mask = self.generate_mask(dst_pad_mask, dst_pad_mask, True)
src_dst_mask = self.generate_mask(dst_pad_mask, src_pad_mask, False)
encoder_kv = self.encoder(x, src_mask)
res = self.decoder(y, encoder_kv, dst_mask, src_dst_mask)
res = self.output_layer(res)
return res

我们一点一点来看。先看初始化函数。初始化函数的输入其实就是 Transformer 模型的超参数。总结一下,Transformer 应该有这些超参数:

  • d_model 模型中大多数词向量表示的维度大小
  • d_ff 前馈网络隐藏层维度大小
  • n_layers 堆叠的 Encoder & Decoder 层数
  • head 多头注意力的头数
  • dropout Dropout 的几率

另外,为了构建嵌入层,要知道源语言、目标语言的词典大小,并且提供pad_idx。为了预处理位置编码,需要提前知道一个最大序列长度。

照着子模块的初始化参数表,把参数归纳到__init__的参数表里即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def __init__(self,
src_vocab_size: int,
dst_vocab_size: int,
pad_idx: int,
d_model: int,
d_ff: int,
n_layers: int,
heads: int,
dropout: float = 0.1,
max_seq_len: int = 200):
super().__init__()
self.encoder = Encoder(src_vocab_size, pad_idx, d_model, d_ff,
n_layers, heads, dropout, max_seq_len)
self.decoder = Decoder(dst_vocab_size, pad_idx, d_model, d_ff,
n_layers, heads, dropout, max_seq_len)
self.pad_idx = pad_idx
self.output_layer = nn.Linear(d_model, dst_vocab_size)

再看一下 forward 函数。forward先预处理好了所有的 mask,再逐步执行 Transformer 的计算:先是通过 Encoder 获得源语言的中间表示encoder_kv,再把它和目标语言y的输入一起传入 Decoder,最后经过线性层输出结果res。由于 PyTorch 的交叉熵损失函数自带了 softmax 操作,这里不需要多此一举。

Transformer 论文提到,softmax 前的那个线性层可以和嵌入层共享权重。也就是说,嵌入和输出前的线性层分别完成了词序号到词嵌入的正反映射,两个操作应该是互逆的。但是,词嵌入矩阵不是一个方阵,它根本不能求逆矩阵。我想破头也没想清楚是怎么让线性层可以和嵌入层共享权重的。网上的所有实现都没有对这个细节多加介绍,只是新建了一个线性层。我也照做了。

1
2
3
4
5
6
7
8
9
10
11
def forward(self, x, y):

src_pad_mask = x == self.pad_idx
dst_pad_mask = y == self.pad_idx
src_mask = self.generate_mask(src_pad_mask, src_pad_mask, False)
dst_mask = self.generate_mask(dst_pad_mask, dst_pad_mask, True)
src_dst_mask = self.generate_mask(dst_pad_mask, src_pad_mask, False)
encoder_kv = self.encoder(x, src_mask)
res = self.decoder(y, encoder_kv, dst_mask, src_dst_mask)
res = self.output_layer(res)
return res

等了很久,现在可以来仔细看一看 mask 的生成方法了。回忆一下,表示该字符是否有效的 mask 有两个来源。第一个是论文里提到的,用于模拟串行推理的 mask;另一个是填充操作的空白字符引入的 mask。generate_mask 用于生成这些 mask。

generate_mask 的输入有 query 句子和 key 句子的 pad mask q_pad, k_pad,它们的形状为[n, seq_len]。若某处为 True,则表示这个地方的字符是<pad>。对于自注意力,query 和 key 都是一样的;而在 Decoder 的第二个多头注意力层中,query 来自目标语言,key 来自源语言。with_left_mask 表示是不是要加入 Decoder 里面的模拟串行推理的 mask,它会在掩码自注意力里用到。

1
2
3
4
def generate_mask(self,
q_pad: torch.Tensor,
k_pad: torch.Tensor,
with_left_mask: bool = False):

一开始,先取好维度信息,定好张量的形状。在注意力操作中,softmax 前的那个量的形状是 [n, heads, q_len, k_len],表示每一批每一个头的每一个query对每个key之间的相似度。每一个头的mask是一样的。因此,除heads维可以广播外,mask 的形状应和它一样。

1
mask_shape = (n, 1, q_len, k_len)

再新建一个表示最终 mask 的张量。如果不用 Decoder 的那种 mask,就生成一个全零的张量;否则,生成一个上三角为0,其余地方为1的张量。注意,在我的代码中,mask 为 True 或1就表示这个地方需要填负无穷。

1
2
3
4
if with_left_mask:
mask = 1 - torch.tril(torch.ones(mask_shape))
else:
mask = torch.zeros(mask_shape)

最后,把有 <pad> 的地方也标记一下。从mask的形状[n, 1, q_len, k_len]可以知道,q_pad 表示哪些行是无效的,k_pad 表示哪些列是无效的。如果query句子的第i个字符是<pad>,则应该令mask[:, :, i, :] = 1; 如果key句子的第j个字符是<pad>,则应该令mask[:, :, :, j] = 1

下面的代码利用了PyTorch的取下标机制,直接并行地完成了mask赋值。

1
2
3
for i in range(n):
mask[i, :, q_pad[i], :] = 1
mask[i, :, :, k_pad[i]] = 1

完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def generate_mask(self,
q_pad: torch.Tensor,
k_pad: torch.Tensor,
with_left_mask: bool = False):
# q_pad shape: [n, q_len]
# k_pad shape: [n, k_len]
# q_pad k_pad dtype: bool
assert q_pad.device == k_pad.device
n, q_len = q_pad.shape
n, k_len = k_pad.shape

mask_shape = (n, 1, q_len, k_len)
if with_left_mask:
mask = 1 - torch.tril(torch.ones(mask_shape))
else:
mask = torch.zeros(mask_shape)
mask = mask.to(q_pad.device)
for i in range(n):
mask[i, :, q_pad[i], :] = 1
mask[i, :, :, k_pad[i]] = 1
mask = mask.to(torch.bool)
return mask

看完了mask的生成方法后,我们回到前一步,看看mask会在哪些地方被调用。

在 Transformer 中,有三类多头注意力层,它们的 mask 也不同。Encoder 的多头注意力层的 query 和 key 都来自源语言;Decoder 的第一个多头注意力层的 query 和 key 都来自目标语言;Decoder 的第二个多头注意力层的 query 来自目标语言, key 来自源语言。另外,Decoder 的第一个多头注意力层要加串行推理的那个 mask。按照上述描述生成mask即可。

1
2
3
4
5
6
7
8
9
10
11
def forward(self, x, y):
src_pad_mask = x == self.pad_idx
dst_pad_mask = y == self.pad_idx
src_mask = self.generate_mask(src_pad_mask, src_pad_mask, False)
dst_mask = self.generate_mask(dst_pad_mask, dst_pad_mask, True)
src_dst_mask = self.generate_mask(dst_pad_mask, src_pad_mask, False)

encoder_kv = self.encoder(x, src_mask)
res = self.decoder(y, encoder_kv, dst_mask, src_dst_mask)
res = self.output_layer(res)
return res

到此,Transfomer 模型总算编写完成了。

这里再帮大家排一个坑。PyTorch的官方Transformer中使用了下面的参数初始化方式。但是,实际测试后,不知道为什么,我发现使用这种初始化会让模型训不起来。

1
2
3
4
def init_weights(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)

我去翻了翻PyTorch的Transformer示例,发现官方的示例根本没用到Transformer,而是用子模块nn.TransformerDecoder, nn.TransformerEncoder自己搭了一个新的Transformer。这些子模块其实都有自己的init_weights方法。看来官方都信不过自己的Transformer,这个Transformer类的初始化方法就有问题。

在我们的代码中,我们不必手动对参数初始化。PyTorch对每个线性层默认的参数初始化方式就够好了。

训练

准备好了模型、数据集后,剩下的工作非常惬意,只要随便调用一下就行了。训练的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import torch.nn as nn
import time

from dldemos.Transformer.data_load import (get_batch_indices, load_cn_vocab,
load_en_vocab, load_train_data,
maxlen)
from dldemos.Transformer.model import Transformer

# Config
batch_size = 64
lr = 0.0001
d_model = 512
d_ff = 2048
n_layers = 6
heads = 8
dropout_rate = 0.2
n_epochs = 60
PAD_ID = 0


def main():
device = 'cuda'
cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()
# X: en
# Y: cn
Y, X = load_train_data()

print_interval = 100

model = Transformer(len(en2idx), len(cn2idx), PAD_ID, d_model, d_ff,
n_layers, heads, dropout_rate, maxlen)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr)

citerion = nn.CrossEntropyLoss(ignore_index=PAD_ID)
tic = time.time()
cnter = 0
for epoch in range(n_epochs):
for index, _ in get_batch_indices(len(X), batch_size):
x_batch = torch.LongTensor(X[index]).to(device)
y_batch = torch.LongTensor(Y[index]).to(device)
y_input = y_batch[:, :-1]
y_label = y_batch[:, 1:]
y_hat = model(x_batch, y_input)

y_label_mask = y_label != PAD_ID
preds = torch.argmax(y_hat, -1)
correct = preds == y_label
acc = torch.sum(y_label_mask * correct) / torch.sum(y_label_mask)

n, seq_len = y_label.shape
y_hat = torch.reshape(y_hat, (n * seq_len, -1))
y_label = torch.reshape(y_label, (n * seq_len, ))
loss = citerion(y_hat, y_label)

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
optimizer.step()

if cnter % print_interval == 0:
toc = time.time()
interval = toc - tic
minutes = int(interval // 60)
seconds = int(interval % 60)
print(f'{cnter:08d} {minutes:02d}:{seconds:02d}'
f' loss: {loss.item()} acc: {acc.item()}')
cnter += 1

model_path = 'dldemos/Transformer/model.pth'
torch.save(model.state_dict(), model_path)

print(f'Model saved to {model_path}')


if __name__ == '__main__':
main()

所有的超参数都写在代码开头。在模型结构上,我使用了和原论文一样的超参数。

1
2
3
4
5
6
7
8
9
10
# Config
batch_size = 64
lr = 0.0001
d_model = 512
d_ff = 2048
n_layers = 6
heads = 8
dropout_rate = 0.2
n_epochs = 60
PAD_ID = 0

之后,进入主函数。一开始,我们调用load_data.py提供的API,获取中英文序号到单词的转换词典,并获取已经打包好的训练数据。

1
2
3
4
5
6
7
def main():
device = 'cuda'
cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()
# X: en
# Y: cn
Y, X = load_train_data()

接着,我们用参数初始化好要用到的对象,比如模型、优化器、损失函数。

1
2
3
4
5
6
7
8
9
10
11
print_interval = 100

model = Transformer(len(en2idx), len(cn2idx), PAD_ID, d_model, d_ff,
n_layers, heads, dropout_rate, maxlen)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr)

citerion = nn.CrossEntropyLoss(ignore_index=PAD_ID)
tic = time.time()
cnter = 0

再然后,进入训练循环。我们从X, Y里取出源语言和目标语言的序号数组,输入进模型里。别忘了,Transformer可以并行训练。我们给模型输入目标语言前n-1个单词,用第2到第n个单词作为监督标签。

1
2
3
4
5
6
7
for epoch in range(n_epochs):
for index, _ in get_batch_indices(len(X), batch_size):
x_batch = torch.LongTensor(X[index]).to(device)
y_batch = torch.LongTensor(Y[index]).to(device)
y_input = y_batch[:, :-1]
y_label = y_batch[:, 1:]
y_hat = model(x_batch, y_input)

得到模型的预测y_hat后,我们可以把输出概率分布中概率最大的那个单词作为模型给出的预测单词,算一个单词预测准确率。当然,我们要排除掉<pad>的影响。

1
2
3
4
y_label_mask = y_label != PAD_ID
preds = torch.argmax(y_hat, -1)
correct = preds == y_label
acc = torch.sum(y_label_mask * correct) / torch.sum(y_label_mask)

我们最后算一下loss,并执行梯度下降,训练代码就写完了。为了让训练更稳定,不出现梯度过大的情况,我们可以用torch.nn.utils.clip_grad_norm_(model.parameters(), 1)裁剪梯度。

1
2
3
4
5
6
7
8
9
n, seq_len = y_label.shape
y_hat = torch.reshape(y_hat, (n * seq_len, -1))
y_label = torch.reshape(y_label, (n * seq_len, ))
loss = citerion(y_hat, y_label)

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
optimizer.step()

实验

在本项目的实验中,使用单卡3090,约10分钟就能完成训练。最终的训练准确率可以到达90%以上。

1
00006300 12:12 loss: 0.43494755029678345 acc: 0.9049844145774841

该数据集没有提供测试集(原仓库里的测试集来自训练集,这显然不合理)。且由于词表太小,不太好构建测试集。因此,我没有编写从测试集里生成句子并算BLEU score的代码,而是写了一份翻译给定句子的代码。要编写测试BLUE score的代码,只需要把翻译任意句子的代码改个输入,加一个求BLEU score的函数即可。这份翻译任意句子的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch

from dldemos.Transformer.data_load import (load_cn_vocab, load_en_vocab,
idx_to_sentence, maxlen)
from dldemos.Transformer.model import Transformer

# Config
batch_size = 1
lr = 0.0001
d_model = 512
d_ff = 2048
n_layers = 6
heads = 8
dropout_rate = 0.2
n_epochs = 60

PAD_ID = 0


def main():
device = 'cuda'
cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()

model = Transformer(len(en2idx), len(cn2idx), 0, d_model, d_ff, n_layers,
heads, dropout_rate, maxlen)
model.to(device)
model.eval()

model_path = 'dldemos/Transformer/model.pth'
model.load_state_dict(torch.load(model_path))

my_input = ['we', "should", "protect", "environment"]
x_batch = torch.LongTensor([[en2idx[x] for x in my_input]]).to(device)

cn_sentence = idx_to_sentence(x_batch[0], idx2en, True)
print(cn_sentence)

y_input = torch.ones(batch_size, maxlen,
dtype=torch.long).to(device) * PAD_ID
y_input[0] = en2idx['<S>']
# y_input = y_batch
with torch.no_grad():
for i in range(1, y_input.shape[1]):
y_hat = model(x_batch, y_input)
for j in range(batch_size):
y_input[j, i] = torch.argmax(y_hat[j, i - 1])
output_sentence = idx_to_sentence(y_input[0], idx2cn, True)
print(output_sentence)


if __name__ == '__main__':
main()

一开始,还是先获取词表,并初始化模型。

1
2
3
4
5
6
7
8
9
10
11
12
def main():
device = 'cuda'
cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()

model = Transformer(len(en2idx), len(cn2idx), 0, d_model, d_ff, n_layers,
heads, dropout_rate, maxlen)
model.to(device)
model.eval()

model_path = 'dldemos/Transformer/model.pth'
model.load_state_dict(torch.load(model_path))

之后,我们用自己定义的句子(要做好分词)代替原来的输入x_batch。如果要测试某个数据集,只要把这里x_batch换成测试集里的数据即可。
我们可以顺便把序号数组用idx_to_sentence转回英文,看看序号转换有没有出错。

1
2
3
4
5
my_input = ['we', "should", "protect", "environment"]
x_batch = torch.LongTensor([[en2idx[x] for x in my_input]]).to(device)

cn_sentence = idx_to_sentence(x_batch[0], idx2en, True)
print(cn_sentence)

这段代码会输出we should protect environment。这说明x_batch是我们想要的序号数组。

最后,我们利用Transformer自回归地生成句子,并输出句子。

1
2
3
4
5
6
7
8
9
10
11
y_input = torch.ones(batch_size, maxlen,
dtype=torch.long).to(device) * PAD_ID
y_input[0] = en2idx['<S>']
# y_input = y_batch
with torch.no_grad():
for i in range(1, y_input.shape[1]):
y_hat = model(x_batch, y_input)
for j in range(batch_size):
y_input[j, i] = torch.argmax(y_hat[j, i - 1])
output_sentence = idx_to_sentence(y_input[0], idx2cn, True)
print(output_sentence)

要自回归地生成句子,我们先给句子填入无效字符<pad>,再把第一个字符换成句子开始字符<S>

1
2
3
y_input = torch.ones(batch_size, maxlen,
dtype=torch.long).to(device) * PAD_ID
y_input[0] = en2idx['<S>']

之后,我们循环调用Transformer,获取下一个单词的概率分布。我们可以认为,概率最大的那个单词就是模型预测的下一个单词。因此,我们可以用argmax获取预测的下一个单词的序号,填回y_input。这里的y_input和训练时那个y_batch是同一个东西。

1
2
3
4
5
6
# y_input = y_batch
with torch.no_grad():
for i in range(1, y_input.shape[1]):
y_hat = model(x_batch, y_input)
for j in range(batch_size):
y_input[j, i] = torch.argmax(y_hat[j, i - 1])

最后只要输出生成的句子即可。

1
2
output_sentence = idx_to_sentence(y_input[0], idx2cn, True)
print(output_sentence)

由于训练数据非常少,而且数据都来自新闻,我只好选择了一个比较常见的句子”we should protect environment”作为输入。模型翻译出了一个比较奇怪的结果。

1
<S> 要 保护 环境 保护 环境 保护 环境 保护 环境 保护 环境 保护 环境 保护 环境 的 生态 环境 落实 好 环境 </S> 环境 </S> 有效 保护 环境 </S>...

可以看出,模型确实学到了东西,能翻译出“要保护环境”。但是,这翻译的结果也太长太奇怪了。感觉是对训练数据过拟合了。当然,还是那句话,训练集里的数据太少。要提升模型性能并缓解过拟合,加数据集是最好的方法。这个结果起码说明我们Tranformer的编写没有问题。

在生成新句子的时候,我直接拿概率最高的单词当做预测的下一个单词。其实,还有一些更加高级的生成算法,比如Beam Search。如果模型训练得比较好,可以用这些高级一点的算法提高生成句子的质量。

我读了网上几份Transformer实现。这些实现在生成句子算BLEU score时,竟然直接输入测试句子的前n-1个单词,把输出的n-1个单词拼起来,作为模型的翻译结果。这个过程等价于告诉你前i个翻译答案,你去输出第i+1个单词,再把每个结果拼起来。这样写肯定是不合理的。正常来说应该是照着我这样自回归地生成翻译句子。大家参考网上的Transformer代码时要多加留心。

总结

只要读懂了 Transfomer 的论文,用 PyTorch 实现一遍 Transformer 是很轻松的。但是,代码实现中有非常多论文不会提及的细节,你自己实现时很容易踩坑。在这篇文章里,我完整地介绍了一个英中翻译 Transformer 的 PyTorch 实现,相信读者能够跟随这篇文章实现自己的 Transformer,并在代码实现的过程中加深对论文的理解。

再稍微总结一下代码实现中的一些值得注意的地方。代码中最大的难点是 mask 的实现。mask 的处理稍有闪失,就可能会让计算结果中遍布 NaN。一定要想清楚各个模块的 mask 是从哪来的,它们在注意力计算里是怎么被用上的。

另外,有两处地方的实现比较灵活。一处是位置编码的实现,一处是多头注意力中怎么描述“多头”。其他模块的实现都大差不差,千篇一律。

最后再提醒一句,要从头训练一个模型,一定要从小数据集上开始做。不然你训练个半天,结果差了,你不知道是数据有问题,还是代码有问题。我之前一直在使用很大的训练集,每次调试都非常麻烦,浪费了很多时间。希望大家引以为戒。

参考资料

感谢 https://github.com/P3n9W31/transformer-pytorch 提供的数据集。

一份简明的Transformer实现代码 https://github.com/hyunwoongko/transformer

一篇不错的Transformer实现教程 https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec

过期内容

我第一次写这篇文章时过于仓促,文章中有不少错误,实验部分也没写完。我后来把本文又重新修改了一遍,补充了实验部分。

我之前使用了一个较大的数据集,但发现做实验做得很慢,于是换了一个较小的数据集。以前的数据集预处理介绍就挪到这里了。

数据集与评测方法

在开启一个深度学习项目之初,要把任务定义好。准确来说,我们要明白这个任务是在完成一个怎样的映射,并准备一个用于评测的数据集,定义好评价指标。

英中翻译,这个任务非常明确,就是把英文的句子翻译成中文。英中翻译的数据集应该包含若干个句子对,每个句子对由一句英文和它对应的中文翻译组成。

中英翻译的数据集不是很好找。有几个比较出名的数据集的链接已经失效了,还有些数据集需要注册与申请后才能获取。我在中文NLP语料库仓库(https://github.com/brightmart/nlp_chinese_corpus)找到了中英文平行语料 translation2019zh。该语料库由520万对中英文语料构成,训练集516万对,验证集3.9万对。用作训练和验证中英翻译模型是足够了。

机器翻译的评测指标叫做BLEU Score。如果模型输出的翻译和参考译文有越多相同的单词、连续2个相同单词、连续3个相同单词……,则得分越高。

PyTorch 提供了便捷的API,我们可以用一行代码算完BLEU Score。

1
2
3
4
5
>>> from torchtext.data.metrics import bleu_score
>>> candidate_corpus = [['My', 'full', 'pytorch', 'test'], ['Another', 'Sentence']]
>>> references_corpus = [[['My', 'full', 'pytorch', 'test'], ['Completely', 'Different']], [['No', 'Match']]]
>>> bleu_score(candidate_corpus, references_corpus)
0.8408964276313782

数据清洗

得到数据集后,下一步要做的是对数据集做处理,把原始数据转化成能够输入神经网络的张量。对于图片,预处理可能是裁剪、缩放,使所有图片都有一样的大小;对于文本,预处理可能是分词、填充。

网盘上下载好 translation2019zh 数据集后,我们来一步一步清洗这个数据集。这个数据集只有两个文件translation2019zh_train.json, translation2019zh_valid.json,它们的结构如下:

text
1
2
3
4
{"english": <english>, "chinese": <chinese>}
{"english": <english>, "chinese": <chinese>}
{"english": <english>, "chinese": <chinese>}
...

这些json文件有点不合标准,每对句子由一行json格式的记录组成。english属性是英文句子,chinese属性是中文句子。比如:

text
1
{"english": "In Italy ...", "chinese": "在意大利 ..."}

因此,在读取数据时,我们可以用下面的代码提取每对句子。

1
2
3
4
5
6
import json

with open(json_path, 'r') as fp:
for line in fp:
line = json.loads(line)
english, chinese = line['english'], line['chinese']

这个数据集有一点不干净,有一些句子对的中英文句子颠倒过来了。为此,我们要稍微处理一下,把这些句子对翻转过来。如果一个英文句子不全由 ASCII 组成,则它可能是一个被标错的中文句子。

1
2
3
# Correct mislabeled data
if not english.isascii():
english, chinese = chinese, english

经过这一步,我们只得到了中英文的字符文本。而在NLP中,大部分处理的最小单位都是符号(token)——对于英文来说,符号是单词、标点;对于中文来说,符号是词语、标点。我们还需要一个符号化的过程。

英文符号化非常方便,torchtext 提供了非常便捷的英文分词 API。

1
2
3
4
from torchtext.data import get_tokenizer

tokenizer = get_tokenizer('basic_english')
english = tokenizer(english)

而中文分词方面,我使用了jieba库。该库可以直接 pip 安装。

1
pip install jieba

分词的 API 是 jieba.cut。由于分词的结果中,相邻的词之间有空格,我一股脑地把所有空白符给过滤掉了。

1
2
3
import jieba
chinese = list(jieba.cut(chinese))
chinese = [x for x in chinese if x not in {' ', '\t'}]

经过这些处理后,每句话被转换成了中文词语或英文单词的数组。整个处理代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def read_file(json_path):
english_sentences = []
chinese_sentences = []
tokenizer = get_tokenizer('basic_english')
with open(json_path, 'r') as fp:
for line in fp:
line = json.loads(line)
english, chinese = line['english'], line['chinese']
# Correct mislabeled data
if not english.isascii():
english, chinese = chinese, english
# Tokenize
english = tokenizer(english)
chinese = list(jieba.cut(chinese))
chinese = [x for x in chinese if x not in {' ', '\t'}]
english_sentences.append(english)
chinese_sentences.append(chinese)
return english_sentences, chinese_sentences

词语转序号

为了让计算机更方便地处理单词,我们还要把单词转换成序号。比如令apple为0号,banana为1号,则句子apple banana apple就转换成了0 1 0

给每一个单词选一个标号,其实就是要建立一个词典。一般来说,我们可以利用他人的统计结果,挑选最常用的一些英文单词和中文词语构成词典。不过,现在我们已经有了一个庞大的中英语料库了,我们可以直接从这个语料库中挑选出最常见的词构成词典。

根据上一步处理得到的句子数组sentences,我们可以用下面的 Python 代码统计出最常见的一些词语,把它们和4个特殊字符<sos>, <eos>, <unk>, <pad>(句子开始字符、句子结束字符、频率太少没有被加入词典的词语、填充字符)一起构成词典。统计字符出现次数是通过 Python 的 Counter 类实现的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from collections import Counter

def create_vocab(sentences, max_element=None):
"""Note that max_element includes special characters"""

default_list = ['<sos>', '<eos>', '<unk>', '<pad>']

char_set = Counter()
for sentence in sentences:
c_set = Counter(sentence)
char_set.update(c_set)

if max_element is None:
return default_list + list(char_set.keys())
else:
max_element -= 4
words_freq = char_set.most_common(max_element)
# pair array to double array
words, freq = zip(*words_freq)
return default_list + list(words)

准备好了词典后,我还编写了两个工具函数sentence_to_tensortensor_to_sentence,它们可以用于字符串数组与序号数组的互相转换。测试这些代码的脚本及其输出如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Dataset.py

def main():
en_sens, zh_sens = read_file(
'data/translation2019zh/translation2019zh_valid.json')
print(*en_sens[0:3])
print(*zh_sens[0:3])
en_vocab = create_vocab(en_sens, 10000)
zh_vocab = create_vocab(zh_sens, 30000)
print(list(en_vocab)[0:10])
print(list(zh_vocab)[0:10])

en_tensors = sentence_to_tensor(en_sens, en_vocab)
zh_tensors = sentence_to_tensor(zh_sens, zh_vocab)

print(tensor_to_sentence(en_tensors[0], en_vocab, True))
print(tensor_to_sentence(zh_tensors[0], zh_vocab))
text
1
2
3
4
5
6
['slowly', 'and', 'not', 'without', 'struggle', ',', 'america', 'began', 'to', 'listen', '.'] ...]
['美国', '缓慢', '地', '开始', '倾听', ',', '但', '并非', '没有', '艰难曲折', '。'] ...]
['<sos>', '<eos>', '<unk>', '<pad>', 'the', '.', ',', 'of', 'and', 'to']
['<sos>', '<eos>', '<unk>', '<pad>', '的', ',', '。', '在', '了', '和']
slowly and not without struggle , america began to listen .
美国缓慢地开始倾听,但并非没有<unk>。

在这一步中,有一个重要的参数:词典的大小。显然,词典越大,能处理的词语越多,但训练速度也会越慢。由于这个项目只是一个用于学习的demo,我设置了比较小的词典大小。想提升整个模型的性能的话,调大词典大小是一个最快的方法。

生成 Dataloader

都说程序员是新时代的农民工,这非常有道理。因为,作为程序员,你免不了要写一些繁重、无聊的数据处理脚本。还好,写完这些无聊的预处理代码后,总算可以使用 PyTorch 的 API 写一些有趣的代码了。

把词语数组转换成序号句子数组后,我们要考虑怎么把序号句子数组输入给模型了。文本数据通常长短不一,为了一次性处理一个 batch 的数据,要把短的句子填充,使得一批句子长度相等。写 Dataloader 时最主要的工作就是填充并对齐句子。

先看一下Dataset的写法。上一步得到的序号句子数组可以塞进Dataset里。注意,每个句子的前后要加上表示句子开始和结束的特殊符号。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
SOS_ID = 0
EOS_ID = 1
UNK_ID = 2
PAD_ID = 3

class TranslationDataset(Dataset):

def __init__(self, en_tensor: np.ndarray, zh_tensor: np.ndarray):
super().__init__()
assert len(en_tensor) == len(zh_tensor)
self.length = len(en_tensor)
self.en_tensor = en_tensor
self.zh_tensor = zh_tensor

def __len__(self):
return self.length

def __getitem__(self, index):
x = np.concatenate(([SOS_ID], self.en_tensor[index], [EOS_ID]))
x = torch.from_numpy(x)
y = np.concatenate(([SOS_ID], self.zh_tensor[index], [EOS_ID]))
y = torch.from_numpy(y)
return x, y

接下来看一下 DataLoader 的写法。在创建 Dataloader 时,最重要的是 collate_fn 的编写,这个函数决定了怎么把多条数据合成一个等长的 batch。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def get_dataloader(en_tensor: np.ndarray,
zh_tensor: np.ndarray,
batch_size=16):

def collate_fn(batch):
...

dataset = TranslationDataset(en_tensor, zh_tensor)
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collate_fn)

return dataloader

collate_fn 的输入是多个 dataset __getitem__ 的返回结果构成的数组。对于我们的 dataset 来说,collate_fn 的输入是 [(x1, y1), (x2, y2), ...] 。我们可以用 zip(*batch) 把二元组数组拆成两个数组 x, y

collate_fn 的输出就是将来 dataloader 的输出。PyTorch 提供了 pad_sequence 函数用来把一批数据填充至等长。

1
2
3
4
5
6
7
8
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
x, y = zip(*batch)
x_pad = pad_sequence(x, batch_first=True, padding_value=PAD_ID)
y_pad = pad_sequence(y, batch_first=True, padding_value=PAD_ID)

return x_pad, y_pad

实现完collate_fn后,我们就可以得到了DataLoader。这样,数据集预处理部分大功告成。

近两年,有许多图像生成类任务的前沿工作都使用了一种叫做”codebook”的机制。追溯起来,codebook机制最早是在VQ-VAE论文中提出的。相比于普通的VAE,VQ-VAE能利用codebook机制把图像编码成离散向量,为图像生成类任务提供了一种新的思路。VQ-VAE的这种建模方法启发了无数的后续工作,包括声名远扬的Stable Diffusion。

在这篇文章中,我将先以易懂的逻辑带领大家一步一步领悟VQ-VAE的核心思想,再介绍VQ-VAE中关键算法的具体形式,最后把VQ-VAE的贡献及其对其他工作的影响做一个总结。通过阅读这篇文章,你不仅能理解VQ-VAE本身的原理,更能知道如何将VQ-VAE中的核心机制活学活用。

从 AE 到 VQ-VAE

为什么VQ-VAE想要把图像编码成离散向量?让我们从最早的自编码器(Autoencoder, AE)开始一步一步谈起。AE是一类能够把图片压缩成较短的向量的神经网络模型,其结构如下图所示。AE包含一个编码器$e()$和一个解码器$d()$。在训练时,输入图像$\mathbf{x}$会被编码成一个较短的向量$\mathbf{z}$,再被解码回另一幅长得差不多的图像$\hat{\mathbf{x}}$。网络的学习目标是让重建出来的图像$\hat{\mathbf{x}}$和原图像$\mathbf{x}$尽可能相似。

解码器可以把一个向量解码成图片。换一个角度看,解码器就是一个图像生成模型,因为它可以根据向量来生成图片。那么,AE可不可以用来做图像生成呢?很可惜,AE的编码器编码出来的向量空间是不规整的。也就是说,解码器只认识经编码器编出来的向量,而不认识其他的向量。如果你把自己随机生成出来的向量输入给解码器,解码器是生成不出有意义的图片的。AE不能够随机生成图片,所以它不能很好地完成图像生成任务,只能起到把图像压缩的作用。

AE离图像生成只差一步了。只要AE的编码空间比较规整,符合某个简单的数学分布(比如最常见的标准正态分布),那我们就可以从这个分布里随机采样向量,再让解码器根据这个向量来完成随机图片生成了。VAE就是这样一种改进版的AE。它用一些巧妙的方法约束了编码向量$\mathbf{z}$,使得$\mathbf{z}$满足标准正态分布。这样,解码器不仅认识编码器编出的向量,还认识其他来自标准正态分布的向量。训练完成后,我们就可以扔掉编码器,用来自标准正态分布的随机向量和解码器来实现随机图像生成了。

VAE的实现细节就不在这里赘述了,是否理解它对理解VQ-VAE没有影响。我们只需知道VAE可以把图片编码成符合标准正态分布的向量即可。让向量符合标准正态分布的原因是方便随机采样。同时,需要强调的是,VAE编码出来的向量是连续向量,也就是向量的每一维都是浮点数。如果把向量的某一维稍微改动0.0001,解码器还是认得这个向量,并且会生成一张和原向量对应图片差不多的图片。

但是,VAE生成出来的图片都不是很好看。VQ-VAE的作者认为,VAE的生成图片之所以质量不高,是因为图片被编码成了连续向量。而实际上,把图片编码成离散向量会更加自然。比如我们想让画家画一个人,我们会说这个是男是女,年龄是偏老还是偏年轻,体型是胖还是壮,而不会说这个人性别是0.5,年龄是0.6,体型是0.7。因此,VQ-VAE会把图片编码成离散向量,如下图所示。

把图像编码成离散向量后,又会带来两个新的问题。第一个问题是,神经网络会默认输入满足一个连续的分布,而不善于处理离散的输入。如果你直接输入0, 1, 2这些数字,神经网络会默认1是一个处于0, 2中间的一种状态。为了解决这一问题,我们可以借鉴NLP中对于离散单词的处理方法。为了处理离散的输入单词,NLP模型的第一层一般都是词嵌入层,它可以把每个输入单词都映射到一个独一无二的连续向量上。这样,每个离散的数字都变成了一个特别的连续向量了。

我们可以把类似的嵌入层加到VQ-VAE的解码器前。这个嵌入层在VQ-VAE里叫做”embedding space(嵌入空间)”,在后续文章中则被称作”codebook”。

离散向量的另一个问题是它不好采样。回忆一下,VAE之所以把图片编码成符合正态分布的连续向量,就是为了能在图像生成时把编码器扔掉,让随机采样出的向量也能通过解码器变成图片。现在倒好,VQ-VAE把图片编码了一个离散向量,这个离散向量构成的空间是不好采样的。VQ-VAE不是面临着和AE一样的问题嘛。

这个问题是无解的。没错!VQ-VAE根本不是一个图像生成模型。它和AE一样,只能很好地完成图像压缩,把图像变成一个短得多的向量,而不支持随机图像生成。VQ-VAE和AE的唯一区别,就是VQ-VAE会编码出离散向量,而AE会编码出连续向量。

可为什么VQ-VAE会被归类到图像生成模型中呢?这是因为VQ-VAE的作者利用VQ-VAE能编码离散向量的特性,使用了一种特别的方法对VQ-VAE的离散编码空间采样。VQ-VAE的作者之前设计了一种图像生成网络,叫做PixelCNN。PixelCNN能拟合一个离散的分布。比如对于图像,PixelCNN能输出某个像素的某个颜色通道取0~255中某个值的概率分布。这不刚好嘛,VQ-VAE也是把图像编码成离散向量。换个更好理解的说法,VQ-VAE能把图像映射成一个「小图像」。我们可以把PixelCNN生成图像的方法搬过来,让PixelCNN学习生成「小图像」。这样,我们就可以用PixelCNN生成离散编码,再利用VQ-VAE的解码器把离散编码变成图像。

让我们来整理一下VQ-VAE的工作过程。

  1. 训练VQ-VAE的编码器和解码器,使得VQ-VAE能把图像变成「小图像」,也能把「小图像」变回图像。
  2. 训练PixelCNN,让它学习怎么生成「小图像」。
  3. 随机采样时,先用PixelCNN采样出「小图像」,再用VQ-VAE把「小图像」翻译成最终的生成图像。

到这里,我们已经学完了VQ-VAE的核心思想。让我们来总结一下。VQ-VAE不是一个VAE,而是一个AE。它的目的是把图像压缩成离散向量。或者换个角度说,它提供了把大图像翻译成「小图像」的方法,也提供了把「小图像」翻译成大图像的方法。这样,一个随机生成大图像的问题,就被转换成了一个等价的随机生成一个较小的「图像」的问题。有一些图像生成模型,比如PixelCNN,更适合拟合离散分布。可以用它们来完成生成「小图像」的问题,填补上VQ-VAE生成图片的最后一片空缺。

VQ-VAE 设计细节

在上一节中,我们虽然认识了VQ-VAE的核心思想,但略过了不少实现细节,比如:

  • VQ-VAE的编码器怎么输出离散向量。
  • VQ-VAE怎么优化编码器和解码器。
  • VQ-VAE怎么优化嵌入空间。

在这一节里,我们来详细探究这些细节。

输出离散编码

想让神经网络输出一个整数,最简单的方法是和多分类模型一样,输出一个Softmax过的概率分布。之后,从概率分布里随机采样一个类别,这个类别的序号就是我们想要的整数。比如在下图中,我们想得到一个由3个整数构成的离散编码,就应该让编码器输出3组logit,再经过Softmax与采样,得到3个整数。

但是,这么做不是最高效的。得到离散编码后,下一步我们又要根据嵌入空间把离散编码转回一个向量。可见,获取离散编码这一步有一点多余。能不能把编码器的输出张量(它之前的名字叫logit)、解码器的输入张量embedding、嵌入空间直接关联起来呢?

VQ-VAE使用了如下方式关联编码器的输出与解码器的输入:假设嵌入空间已经训练完毕,对于编码器的每个输出向量$z_e(x)$,找出它在嵌入空间里的最近邻$z_q(x)$,把$z_e(x)$替换成$z_q(x)$作为解码器的输入。

求最近邻,即先计算向量与嵌入空间$K$个向量每个向量的距离,再对距离数组取一个argmin,求出最近的下标(比如图中的0, 1, 1),最后用下标去嵌入空间里取向量。下标构成的数组(比如图中的[0, 1, 1])也正是VQ-VAE的离散编码。

就这样,我们知道了VQ-VAE是怎么生成离散编码的。VQ-VAE的编码器其实不会显式地输出离散编码,而是输出了多个「假嵌入」$z_e(x)$。之后,VQ-VAE对每个$z_e(x)$在嵌入空间里找最近邻,得到真正的嵌入$z_q(x)$,把$z_q(x)$作为解码器的输入。

虽然我们现在能把编码器和解码器拼接到一起,但现在又多出了一个问题:怎么让梯度从解码器的输入$z_q(x)$传到$z_e(x)$?从$z_e(x)$到$z_q(x)$的变换是一个从数组里取值的操作,这个操作是求不了导的。我们在下一小节里来详细探究一下怎么优化VQ-VAE的编码器和解码器。

优化编码器和解码器

为了优化编码器和解码器,我们先来制订一下VQ-VAE的整体优化目标。由于VQ-VAE其实是一个AE,误差函数里应该只有原图像和目标图像的重建误差。

或者非要从VAE的角度说也行。VQ-VAE相当于输出了一个one-hot离散分布。假设输入图像$x$的离散编码$z$是$k$,则分布中仅有$q(z=k|x)=1$,$q(z=others|x)=0$。令离散编码$z$的先验分布是均匀分布(假设不知道输入图像$x$,每个离散编码取到的概率是等同的),则先验分布$q(z)$和后验分布$q(z|x)$的KL散度是常量。因此,KL散度项不用算入损失函数里。理解此处的数学推导意义不大,还不如直接理解成VQ-VAE其实是一个AE。

但直接拿这个误差来训练是不行的。误差中,$z_q(x)$是解码器的输入。从编码器输出$z_e(x)$到$z_q(x)$这一步是不可导的,误差无法从解码器传递到编码器上。要是可以把$z_q(x)$的梯度直接原封不动地复制到$z_e(x)$上就好了。

VQ-VAE使用了一种叫做”straight-through estimator”的技术来完成梯度复制。这种技术是说,前向传播和反向传播的计算可以不对应。你可以为一个运算随意设计求梯度的方法。基于这一技术,VQ-VAE使用了一种叫做$sg$(stop gradient,停止梯度)的运算:

也就是说,前向传播时,$sg$里的值不变;反向传播时,$sg$按值为0求导,即此次计算无梯度。(反向传播其实不会用到式子的值,只会用到式子的梯度。反向传播用到的loss值是在前向传播中算的)。

基于这种运算,我们可以设计一个把梯度从$z_e(x)$复制到$z_q(x)$的误差:

也就是说,前向传播时,就是拿解码器输入$z_q(x)$来算梯度。

而反向传播时,按下面这个公式求梯度,等价于把解码器的梯度全部传给$z_e(x)$。

这部分的PyTorch实现如下所示。在PyTorch里,(x).detach()就是$sg(x)$,它的值在前向传播时取x,反向传播时取0

1
L = x - decoder(z_e + (z_q - z_e).detach())

通过这一技巧,我们完成了梯度的传递,可以正常地训练编码器和解码器了。

优化嵌入空间

到目前为止,我们的讨论都是建立在嵌入空间已经训练完毕的前提上的。现在,我们来讨论一下嵌入空间的训练方法。

嵌入空间的优化目标是什么呢?嵌入空间的每一个向量应该能概括一类编码器输出的向量,比如一个表示「青年」的向量应该能概括所有14-35岁的人的照片的编码器输出。因此,嵌入空间的向量应该和其对应编码器输出尽可能接近。如下面的公式所示,$z_e(x)$是编码器的输出向量,$z_q(x)$是其在嵌入空间的最近邻向量。

但作者认为,编码器和嵌入向量的学习速度应该不一样快。于是,他们再次使用了停止梯度的技巧,把上面那个误差函数拆成了两部分。其中,$\beta$控制了编码器的相对学习速度。作者发现,算法对$\beta$的变化不敏感,$\beta$取0.1~2.0都差不多。

其实,在论文中,作者分别讨论了上面公式里的两个误差。第一个误差来自字典学习算法里的经典算法Vector Quantisation(VQ),也就是VQ-VAE里的那个VQ,它用于优化嵌入空间。第二个误差叫做专注误差,它用于约束编码器的输出,不让它跑到离嵌入空间里的向量太远的地方。

这样,VQ-VAE总体的损失函数可以写成:(由于算上了重建误差,我们多加一个$\alpha$用于控制不同误差之间的比例)

总结

VQ-VAE是一个把图像编码成离散向量的图像压缩模型。为了让神经网络理解离散编码,VQ-VAE借鉴了NLP的思想,让每个离散编码值对应一个嵌入,所有的嵌入都存储在一个嵌入空间(又称”codebook”)里。这样,VQ-VAE编码器的输出是若干个「假嵌入」,「假嵌入」会被替换成嵌入空间里最近的真嵌入,输入进解码器里。

VQ-VAE的优化目标由两部分组成:重建误差和嵌入空间误差。重建误差为输入图片和重建图片的均方误差。为了让梯度从解码器传到编码器,作者使用了一种巧妙的停止梯度算子,让正向传播和反向传播按照不同的方式计算。嵌入空间误差为嵌入和其对应的编码器输出的均方误差。为了让嵌入和编码器以不同的速度优化,作者再次使用了停止梯度算子,把嵌入的更新和编码器的更新分开计算。

训练完成后,为了实现随机图像生成,需要对VQ-VAE的离散分布采样,再把采样出来的离散向量对应的嵌入输入进解码器。VQ-VAE论文使用了PixelCNN来采样离散分布。实际上,PixelCNN不是唯一一种可用的拟合离散分布的模型。我们可以把它换成Transformer,甚至是diffusion模型。如果你当年看完VQ-VAE后立刻把PixelCNN换成了diffusion模型,那么恭喜你,你差不多提前设计出了Stable Diffusion。

可见,VQ-VAE最大的贡献是提供了一种图像压缩思路,把生成大图像的问题转换成了一个更简单的生成「小图像」的问题。图像压缩成离散向量时主要借助了嵌入空间,或者说”codebook”这一工具。这种解决问题的思路可以应用到所有图像生成类任务上,比如超分辨率、图像修复、图像去模糊等。所以近两年我们能看到很多使用了codebook的图像生成类工作。

参考资料

PixelCNN的介绍可以参见我之前的文章:详解PixelCNN大家族。

VQ-VAE的论文为Neural Discrete Representation Learning。这篇文章不是很好读懂,建议直接读我的这篇解读。再推荐另一份还不错的中文解读 https://www.spaces.ac.cn/archives/6760。

图像生成是一个较难建模的任务。为此,我们要用GAN、VAE、Diffusion等精巧的架构来建模图像生成。可是,在NLP中,文本生成却有一种非常简单的实现方法。NLP中有一种基础的概率模型——N元语言模型。N元语言模型可以根据句子的前几个字预测出下一个字的出现概率。比如看到「我爱吃苹……」这句话的前几个字,我们不难猜出下一个字大概率是「果」字。利用N元语言模型,我们可以轻松地实现一个文本生成算法:输入空句子,采样出第一个字;输入第一个字,采样出第二个字;输入前两个字,输出第三个字……以此类推。

既然如此,我们可不可以把相同的方法搬到图像生成里呢?当然可以。虽然图像是二维的数据,不像一维的文本一样有先后顺序,但是我们可以强行给图像的每个像素规定一个顺序。比如,我们可以从左到右,从上到下地给图像标上序号。这样,从逻辑上看,图像也是一个一维数据,可以用NLP中的方法来按照序号实现图像生成了。

PixelCNN就是一个使用这种方法生成图像的模型。可为什么PixelCNN的名气没有GAN、VAE那么大?为什么PixelCNN可以用CNN而不是RNN来处理一维化图像?为什么PixelCNN是一种「自回归模型」?别急,在这篇文章中,我们将认识PixelCNN及其改进模型Gated PixelCNN和PixelCNN++,并认真学习它们的实现代码。看完文章后,这些问题都会迎刃而解。

PixelCNN

如前所述,PixelCNN借用了NLP里的方法来生成图像。模型会根据前i - 1个像素输出第i个像素的概率分布。训练时,和多分类任务一样,要根据第i个像素的真值和预测的概率分布求交叉熵损失函数;采样时,直接从预测的概率分布里采样出第i个像素。根据这些线索,我们来尝试自己「发明」一遍PixelCNN。

这种模型最朴素的实现方法,是输入一幅图像的前i - 1个像素,输出第i个像素的概率分布,即第i个像素取某种颜色的概率的数组。为了方便讨论,我们先只考虑单通道图像,每个像素的颜色取值只有256种。因此,准确来说,模型的输出是256个经过softmax的概率。这样,我们得到了一个V1.0版本的模型。

等等,模型不是叫「PixelCNN」吗?CNN跑哪去了?的确,对于图像数据,最好还是使用CNN,快捷又有效。因此,我们应该修改模型,令模型的输入为整幅图像和序号i。我们根据序号i,过滤掉ii之后的像素,用CNN处理图像。输出部分还是保持一致。

V2.0并不是最终版本,我们可以暂时不用考虑实现细节,比如这里的「过滤」是怎么实现的。硬要做的话,这种过滤也可以暴力实现:把无效像素初始化为0,每次卷积后再把无效像素置0。

改进之后,V2.0版本的模型确实能快速计算第i个像素的概率分布了。可是,CNN是很擅长同时生成一个和原图像长宽相同的张量的,只算一个像素的概率分布还称不上高效。所以,我们可以让模型输入一幅图像,同时输出图像每一处的概率分布。

这次的改进并不能加速采样。但是,在训练时,由于整幅训练图像已知,我们可以在一次前向传播后得到图像每一处的概率分布。假设图像有N个像素,我们就等于是在并行地训练N个样本,训练速度快了N倍!

这种并行训练的想法和Transformer如出一辙。

V3.0版本的PixelCNN已经和论文里的PixelCNN非常接近了,我们来探讨一下网络的实现细节。相比普通的CNN,PixelCNN有一个特别的约束:第i个像素只能看到前i-1个像素的信息,不能看到第i个像素及后续像素的信息。对于V2.0版本只要输出一个概率分布的PixelCNN,我们可以通过一些简单处理过滤掉第i个像素之后的信息。而对于并行输出所有概率分布的V3.0版本,让每个像素都忽略后续像素的信息的方法就不是那么显然了。

PixelCNN论文里提出了一种掩码卷积机制,这种机制可以巧妙地掩盖住每个像素右侧和下侧的信息。具体来说,PixelCNN使用了两类掩码卷积,我们把两类掩码卷积分别称为「A类」和「B类」。二者都是对卷积操作的卷积核做了掩码处理,使得卷积核的右下部分不产生贡献。A类和B类的唯一区别在于卷积核的中心像素是否产生贡献。CNN的第一个的卷积层使用A类掩码卷积,之后每一层的都使用B类掩码卷积。如下图所示。

为什么要先用一次A类掩码卷积,再每次使用B类掩码卷积呢?我们不妨来做一个实验。对于一个7x7的图像,我们先用1次3x3 A类掩码卷积,再用若干次3x3 B类掩码卷积。我们观察图像中心处的像素在每次卷积后的感受野(即输入图像中哪些像素的信息能够传递到中心像素上)。

不难看出,经过了第一个A类掩码卷积后,每个像素就已经看不到自己位置上的输入信息了。再经过两次B类卷积,中心像素能够看到左上角大部分像素的信息。这满足PixelCNN的约束。

而如果一直使用A类卷积,每次卷积后中心像素都会看漏一些信息(不妨对比下面这张示意图和上面那张示意图)。多卷几层后,中心像素的值就会和输入图像毫无关系。

只是用B类卷积也是不行的。显然,如果第一层就使用B类卷积,中心像素还是能看到自己位置的输入信息。这打破了PixelCNN的约束。这下,我们能明白为什么只能先用一次A类卷积,再用若干次B类卷积了。

利用两类掩码卷积,PixelCNN满足了每个像素只能接受之前像素的信息这一约束。除此之外,PixelCNN就没有什么特别的地方了。我们可以用任意一种CNN架构来实现PixelCNN。PixelCNN论文使用了一种类似于ResNet的架构。其中,第一个7x7卷积层用了A类掩码卷积,之后所有3x3卷积都是B类掩码卷积。

到目前为止,我们已经成功搭建了处理单通道图像的PixelCNN。现在,我们来尝试把它推广到多通道图像上。相比于单通道图像,多通道图像只不过是一个像素由多个颜色分量组成。我们可以把一个像素的颜色分量看成是子像素。在定义约束关系时,我们规定一个子像素只由它之前的子像素决定。比如对于RGB图像,R子像素由它之前所有像素决定,G子像素由它的R子像素和之前所有像素决定,B子像素由它的R、G子像素和它之前所有像素决定。生成图像时,我们一个子像素一个子像素地生成。

把我们的PixelCNN V3.0推广到RGB图像时,我们要做的第一件事就是修改网络的通道数量。由于现在要预测三个颜色通道,网络的输出应该是一个[256x3, H, W]形状的张量,即每个像素输出三个概率分布,分别表示R、G、B取某种颜色的概率。同时,本质上来讲,网络是在并行地为每个像素计算3组结果。因此,为了达到同样的性能,网络所有的特征图的通道数也要乘3。

这里说网络中间的通道数要乘3只是一种方便理解的说法。实际上,中间的通道数可以随意设置,是不是3的倍数都无所谓,只是所有通道在逻辑上被分成了3组。我们稍后会利用到「中间结果的通道数应该能被拆成3组」这一性质。

图像变为多通道后,A类卷积和B类卷积的定义也需要做出一些调整。我们不仅要考虑像素在空间上的约束,还要考虑一个像素内子像素间的约束。为此,我们要用不同的策略实现约束。为了方便描述,我们设卷积核组的形状为[o, i, h, w],其中o为输出通道数,i为输入通道数,h, w为卷积核的高和宽。

  1. 对于通道间的约束,我们要在o, i两个维度上设置掩码。设输出通道可以被拆成三组o1, o2, o3,输入通道可以被拆成三组i1, i2, i3,即o1 = 0:o/3, o2 = o/3:o*2/3, o3 = o*2/3:oi1 = 0:i/3, i2 = i/3:i*2/3, i3 = i*2/3:i。序号1, 2, 3分别表示这组通道是在维护R, G, B的计算。我们对输入通道组和输出通道组之间进行约束。对于A类卷积,我们令o1看不到i1, i2, i3o2看不到i2, i3o3看不到i3;对于B类卷积,我们取消每个通道看不到自己的限制,即在A类卷积的基础上令o1看到i1o2看到i2o3看到i3

  2. 对于空间上的约束,我们还是和之前一样,在h, w两个维度上设置掩码。由于「是否看到自己」的处理已经在o, i两个维度里做好了,我们直接在空间上用原来的B类卷积就行。

就这样,修改了通道数,修改了卷积核的掩码后,我们成功实现了论文里的PixelCNN。让我们把这个过程总结一下。PixelCNN的核心思想是给图像的子像素定义一个先后顺序,之后让每个子像素的颜色取值分布由之前所有的子像素决定。实现PixelCNN时,可以用任意一种CNN架构,并注意两点:

  1. 网络的输出是一个经softmax的概率分布。
  2. 网络的所有卷积层要替换成带掩码的卷积层,第一个卷积层用A类掩码,后面的用B类掩码。

学完了PixelCNN,我们在闲暇之余来谈一谈PixelCNN和其他生成网络的对比情况。精通数学的人,会把图像生成问题看成学习一个图像的分布。每次生成一张图片,就是在图像分布里随机采样一个张量。学习一个分布,最便捷的方法是定义一个带参数$\theta$的概率模型$P_\theta$,最大化来自数据集的图像$\mathbf{x}$的概率$P_\theta(\mathbf{x})$。

可问题来了:一个又方便采样,又能计算概率的模型不好设计。VAE和Diffusion建模了把一个来自正态分布的向量$\mathbf{z}$变化成$\mathbf{x}$的过程,并使用了统计学里的变分推理,求出了$P_\theta(\mathbf{x})$的一个下界,再设法优化这个下界。GAN干脆放弃了概率模型,直接拿一个神经网络来评价生成的图像好不好。

PixelCNN则正面挑战了建立概率模型这一任务。它把$P_\theta(\mathbf{x})$定义为每个子像素出现概率的乘积,而每个子像素的概率仅由它之前的子像素决定。

由于我们可以轻松地用神经网络建模每个子像素的概率分布并完成采样,PixelCNN的采样也是很方便的。我们可以说PixelCNN是一个既方便采样,又能快速地求出图像概率的模型。

相比与其他生成模型,PixelCNN直接对$P_\theta(\mathbf{x})$建模,在和概率相关的指标上表现优秀。很可惜,能最大化数据集的图像的出现概率,并不代表图像的生成质量就很优秀。因此,一直以来,以PixelCNN为代表的对概率直接建模的生成模型没有受到过多的关注。可能只有少数必须要计算图像概率分布的任务才会用到PixelCNN。

除了能直接计算图像的概率外,PixelCNN还有一大特点:PixelCNN能输出离散的颜色值。VAE和GAN这些模型都是把图像的颜色看成一个连续的浮点数,模型的输入和输出的取值范围都位于-1到1之间(有些模型是0到1之间)。而PixelCNN则输出的是像素取某个颜色的概率分布,它能描述的颜色是有限而确定的。假如我们是在生成8位单通道图像,那网络就只输出256个离散的概率分布。能生成离散输出这一特性启发了后续很多生成模型。另外,这一特性也允许我们指定颜色的亮度级别。比如对于黑白手写数字数据集MNIST,我们完全可以用黑、白两种颜色来描述图像,而不是非得用256个灰度级来描述图像。减少亮度级别后,网络的训练速度能快上很多。

在后续的文献中,PixelCNN被归类为了自回归生成模型。这是因为PixelCNN在生成图像时,要先输入空图像,得到第一个像素;把第一个像素填入空图像,输入进模型,得到第二个像素……。也就是说,一个图像被不断扔进模型,不断把上一时刻的输出做为输入。这种用自己之前时刻的状态预测下一个状态的模型,在统计学里被称为自回归模型。如果你在其他图像生成文献中见到了「自回归模型」这个词,它大概率指的就是PixelCNN这种每次生成一个像素,该像素由之前所有像素决定的生成模型。

Gated PixelCNN

首篇提出PixelCNN的论文叫做Pixel Recurrent Neural Networks。没错!这篇文章的作者提出了一种叫做PixelRNN的架构,PixelCNN只是PixelRNN的一个变种。可能作者一开始也没指望PixelCNN有多强。后来,人们发现PixelCNN的想法还挺有趣的,但是原始的PixelCNN设计得太烂了,于是开始着手改进原始的PixelCNN。

PixelCNN的掩码卷积其实有一个重大漏洞:像素存在视野盲区。如下图所示,在我们刚刚的实验中,中心像素看不到右上角三个本应该能看到的像素。哪怕你对用B类卷积多卷几次,右上角的视野盲区都不会消失。

为此,PixelCNN论文的作者们又打了一些补丁,发表了Conditional Image Generation with PixelCNN Decoders这篇论文。这篇论文提出了一种叫做Gated PixelCNN的改进架构。Gated PixelCNN使用了一种更好的掩码卷积机制,消除了原PixelCNN里的视野盲区。如下图所示,Gated PixelCNN使用了两种卷积——垂直卷积和水平卷积——来分别维护一个像素上侧的信息和左侧的信息。垂直卷积的结果只是一些临时量,而水平卷积的结果最终会被网络输出。可以看出,使用这种新的掩码卷积机制后,每个像素能正确地收到之前所有像素的信息了。

除此之外,Gated PixelCNN还把网络中的激活函数从ReLU换成了LSTM的门结构。Gated PixelCNN用下图的模块代替了原PixelCNN的普通残差模块。
模块的输入输出都是两个量,左边的量是垂直卷积中间结果,右边的量是最后用来计算输出的量。垂直卷积的结果会经过偏移和一个1x1卷积,再加到水平卷积的结果上。两条计算路线在输出前都会经过门激活单元。所谓门激活单元,就是输入两个形状相同的量,一个做tanh,一个做sigmoid,两个结果相乘再输出。此外,模块右侧那部分还有一个残差连接。

除了上面的两项改动,Gated PixelCNN还做出了其他的一些改动。比如,Gated PixelCNN支持带约束的图像生成,比如根据文字生成图片、根据类别生成图片。用于约束生成的向量$\mathbf{h}$会被输入进网络每一层的激活函数中。当然,这些改动不是为了提升原PixelCNN的性能。

PixelCNN++

之后,VAE的作者也下场了,提出了一种改进版的PixelCNN,叫做PixelCNN++。这篇论文没有多余的废话,在摘要中就简明地指出了PixelCNN++的几项改动:

  1. 使用logistic分布代替256路softmax
  2. 简化RGB子像素之间的约束关系
  3. 使用U-Net架构
  4. 使用dropout正则化

这几项改动中,第一项改动是最具启发性的,这一技巧可以拓展到其他任务上。让我们主要学习一下第一项改动,并稍微浏览一下其他几项改动。

离散logistic混合似然

原PixelCNN使用256路softmax建模一个像素的颜色概率分布。这么做确实能让模型更加灵活,但有若干缺点。首先,计算这么多的概率值会占很多内存;其次,由于每次训练只有一个位置的标签为1,其他255个位置的标签都是0,模型可学习参数的梯度会很稀疏;最后,在这种概率分布方式下,256种颜色是分开考虑的,这导致模型并不知道相邻的颜色比较相似(比如颜色值128和127、129比较相似)这一事实。总之,用softmax独立地表示各种颜色有着诸多的不足。

作者把颜色的概率分布建模成了连续分布,一下子克服掉了上述所有难题。让我们来仔细看一下新概率分布的定义方法。

首先,新概率分布使用到的连续分布叫做logistic分布。它有两个参数:均值$\mu$和方差$s^2$。它的概率密度函数为:

logistic分布的概率密度函数看起来比较复杂。但是,如果对这个函数积分,得到的累计分布函数就是logistic函数。如果令均值为0,方差为1,则logistic函数就是我们熟悉的sigmoid函数了。

接着,每个分布可能是$K$个参数不同的logistic分布中的某一个,选择某个logistic分布的概率由$\pi_i$表示。比如$K=2$,$\pi_1 = 0.3, \pi_2=0.7$,就说明有两个可选的logisti分布,每个分布有30%的概率会使用1号logistic分布,有70%的概率会使用2号logistic分布。 这里的$\pi_i$和原来256路softmax的输出的意义一样,都是选择某个东西的概率。当然,$K$会比256要小很多,不然这种改进就起不到减小计算量的作用了。设一个输出颜色为$v$,它的数学表达式为:

可logsitc分布是一个连续分布,而我们想得到256个颜色中某个颜色的概率,即得到一个离散的分布。因此,在最后一步,我们要从上面这个连续分布里得到一个离散的分布。我们先不管$K$和$\pi_i$,只考虑有一个logistic分布的情况。根据统计学知识可知,要从连续分布里得到一个离散分布,可以把定义域拆成若干个区间,对每个区间的概率求积分。在我们的例子里,我们可以把实数集拆成256个区间,令$(-\infty, 0.5]$为第1个区间,$(0.5, 1.5]$为第2个区间,……,$(253.5, 254.5]$为第255个区间, $(254.5, +\infty)$为第256个区间。

对概率密度函数求积分,就是在累积分布函数上做差。因此,对于某个离散颜色值$x\in[0, 255], x\in \mathbb{N}$,已知一个logistic分布$logistic(\mu, s)$,则这个颜色值的出现概率是:

其中,$\sigma()$是sigmoid函数。$\sigma((x-\mu)/s)$就是分布的累积分布函数。

可以看出,使用这种区间划分方法,位于0处和位于255处的颜色的概率相对会高一点。这一特点符合作者统计出的CIFAR-10里的颜色分布规律。

当有$K$个logistic分布时,只要把各个分布的概率做一个加权和就行(公式省略掉了$x$位于边界处的情况)。

至此,我们已经知道了怎么用一个「离散logistic混合似然」来建模颜色的概率分布了。这个更高级的颜色分布以logistic分布为基础,以比例(概率)$\pi_i$混合了$K$个logstic分布,并用巧妙的方法把连续分布转换成了离散分布。

简化RGB子像素之间的约束关系

在原PixelCNN中,生成一个像素的RGB三个子像素时,为了保证子像素之间的约束,我们要把模型中所有特征图的通道分成三组,并用掩码来维持三组通道间的约束。这样做太麻烦了。因此,PixelCNN++对约束做了一定的简化:根据之前所有像素,网络一次性输出三个子像素的均值和方差,而不用掩码区分三个子像素的信息。当然,只是这样做是不够好的——G子像素缺少了R子像素的信息,B子像素缺少了R、G子像素的信息。为了弥补信息的缺失,PixelCNN会为每个像素额外输出三个参数$\alpha, \beta, \gamma$,$\alpha$描述R对G子像素的约束关系,$\beta$描述R对B的约束关系,$\gamma$描述G对B的约束关系。

让我们来用公式更清晰地描述这一过程。对于某个像素的第$i$个logistic分布,网络会输出10个参数:$\pi, \mu_r, \mu_g, \mu_b, s_r, s_g, s_b, \alpha, \beta, \gamma$。$\pi$就是之前见过的选择第$i$个分布的概率,$\mu_r, \mu_g, \mu_b$是网络输出的三个子像素的均值,$s_r, s_g, s_b$是网络输出的三个子像素的标准差,$\alpha, \beta, \gamma$描述子像素之间的约束。

由于缺少了其他子像素的信息,网络直接输出的$\mu_g, \mu_b$是不准的。我们假设子像素之间仅存在简单的线性关系。这样,可以用下面的公式更新$\mu_g$和$\mu_b$:

更新后的$\mu_g$和$\mu_b$才是训练和采样时使用的最终均值。

你会不会疑惑上面那个公式里的$r$和$g$是哪里来的?别忘了,虽然子像素之间的约束被简化了,但是三个子像素还是按先后顺序生成的。在训练时,我们是知道所有子像素的真值的,公式里的$r$和$g$来自真值;而在采样时,我们会先用神经网络生成三个子像素的均值和方差,再采样$r$,把采样的$r$套入公式采样出$g$,最后把采样的$r,g$套入公式采样出$b$.

使用U-Net架构

PixelCNN++的网络架构是一个三级U-Net,即网络先下采样两次再上采样两次,同级编码器(下采样部分)的输出会连到解码器(上采样部分)的输入上。这个U-Net和其他任务中的U-Net没什么太大的区别。

使用Dropout

过拟合会导致生成图像的观感不好。为此,PixelCNN++采用了Dropout正则化方法,在每个子模块的第一个卷积后面加了一个Dropout。

除了这些改动外,PixelCNN++还使用了类似于Gated PixelCNN里垂直卷积和水平卷积的设计,以消除原PixelCNN里的视野盲区。当然,这点不算做本文的主要贡献。

总结

PixelCNN把文本生成的想法套入了图像生成中,令子像素的生成有一个先后顺序。为了在维护先后顺序的同时执行并行训练,PixelCNN使用了掩码卷积。这种并行训练与掩码的设计和Transformer一模一样。如果你理解了Transformer,就能一眼看懂PixelCNN的原理。

相比与其他的图像生成模型,以PixelCNN为代表的自回归模型在生成效果上并不优秀。但是,PixelCNN有两个特点:能准确计算某图像在模型里的出现概率(准确来说在统计学里叫做「似然」)、能生成离散的颜色输出。这些特性为后续诸多工作铺平了道路。

原版的PixelCNN有很多缺陷,后续很多工作对其进行了改进。Gated PixelCNN主要消除了原PixelCNN里的视野盲区,而PixelCNN++提出了一种泛用性很强的用连续分布建模离散颜色值的方法,并用简单的线性约束代替了原先较为复杂的用神经网络表示的子像素之间的约束。

PixelCNN相关的知识难度不高,了解本文介绍的内容足矣。PixelCNN也不是很常见的架构,复现代码的优先级不高,有时间的话阅读一下本文附录中的代码即可。另外,PixelCNN的代码实现里有一个重要的知识点。这个知识点几乎不会在论文和网上的文章里看到,但它对实现是否成功有着重要的影响。如果你对新知识感兴趣,推荐去读一下附录中对其的介绍。

参考资料与学习提示

Pixel Recurrent Neural Networks 是提出了PixelCNN的文章。当然,这篇文章主要是在讲PixelRNN,只想学PixelCNN的话通读这篇文章的价值不大。

Conditional Image Generation with PixelCNN Decoders 是提出Gated PixelCNN的文章。可以主要阅读消除视野盲区和门激活函数的部分。

PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications 是提出PixelCNN++的文章。整篇文章非常简练,可以整体阅读一遍,并且着重阅读离散logistic混合似然的部分。不过,这篇文章有很多地方写得过于简单了,连公式里的字母都不好好交代清楚,我还是看代码才看懂他们想讲什么。建议搭配本文的讲解阅读。

这几篇文章都使用了NLL(负对数似然)这个评价指标。实际上,这个指标就是对所有数据在模型里的平均出现概率取了个对数,加了个负号。对于PixelCNN,其NLL就是交叉熵损失函数。其他生成模型不是直接对数据的概率分布建模,它们的NLL不好求得。比如diffusion模型只能计算NLL的一个上界。

网上还有几份PyTorch代码复现供参考:

PixelCNN:https://github.com/singh-hrituraj/PixelCNN-Pytorch

Gated PixelCNN:https://github.com/anordertoreclaim/PixelCNN

附录:代码学习

在附录中,我将给出PixelCNN和Gated PixelCNN的PyTorch实现,并讲解PixelCNN++开源代码的实现细节。

PixelCNN 与 GatedPixelCNN

为了简化实现,我们来实现MNIST上的PixelCNN和Gated PixelCNN。MNIST是单通道数据集,我们不用考虑颜色通道之间复杂的约束。代码仓库:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/pixelcnn。

我们先准备好数据集。PyTorch的torchvision提供了获取了MNIST的接口,我们只需要用下面的函数就可以生成MNIST的Dataset实例。参数中,root为数据集的下载路径,download为是否自动下载数据集。令download=True的话,第一次调用该函数时会自动下载数据集,而第二次之后就不用下载了,函数会读取存储在root里的数据。

1
mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True)

我们可以用下面的代码来下载MNIST并输出该数据集的一些信息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torchvision
from torchvision.transforms import ToTensor
def download_dataset():
mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True)
print('length of MNIST', len(mnist))
id = 4
img, label = mnist[id]
print(img)
print(label)

# On computer with monitor
# img.show()

img.save('work_dirs/tmp.jpg')
tensor = ToTensor()(img)
print(tensor.shape)
print(tensor.max())
print(tensor.min())

if __name__ == '__main__':
import os
os.makedirs('work_dirs', exist_ok=True)
download_dataset()

执行这段代码,输出大致为:

1
2
3
4
5
6
length of MNIST 60000
<PIL.Image.Image image mode=L size=28x28 at 0x7FB3F09CCE50>
9
torch.Size([1, 28, 28])
tensor(1.)
tensor(0.)

第一行输出表明,MNIST数据集里有60000张图片。而从第二行和第三行输出中,我们发现每一项数据由图片和标签组成,图片是大小为28x28的PIL格式的图片,标签表明该图片是哪个数字。我们可以用torchvision里的ToTensor()把PIL图片转成PyTorch张量,进一步查看图片的信息。最后三行输出表明,每一张图片都是单通道图片(灰度图),颜色值的取值范围是0~1。

我们可以查看一下每张图片的样子。如果你是在用带显示器的电脑,可以去掉img.show那一行的注释,直接查看图片;如果你是在用服务器,可以去img.save的路径里查看图片。该图片的应该长这个样子:

我们可以用下面的代码预处理数据并创建DataLoader。PixelCNN对输入图片的颜色取值没有特别的要求,我们可以不对图片的颜色取值做处理,保持取值范围在0~1即可。

1
2
3
4
5
6
from torch.utils.data import DataLoader

def get_dataloader(batch_size: int):
dataset = torchvision.datasets.MNIST(root='./data/mnist',
transform=ToTensor())
return DataLoader(dataset, batch_size=batch_size, shuffle=True)

准备好数据后,我们来实现PixelCNN和Gated PixelCNN。先从PixelCNN开始。

实现PixelCNN,最重要的是实现掩码卷积。其代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskConv2d(nn.Module):

def __init__(self, conv_type, *args, **kwags):
super().__init__()
assert conv_type in ('A', 'B')
self.conv = nn.Conv2d(*args, **kwags)
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[0:H // 2] = 1
mask[H // 2, 0:W // 2] = 1
if conv_type == 'B':
mask[H // 2, W // 2] = 1
mask = mask.reshape((1, 1, H, W))
self.register_buffer('mask', mask, False)

def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res

掩码卷积的实现思路就是在卷积核组上设置一个mask。在前向传播的时候,先让卷积核组乘mask,再做普通的卷积。因此,掩码卷积类里需要实现一个普通卷积的操作。实现普通卷积,既可以写成继承nn.Conv2d,也可以把nn.Conv2d的实例当成成员变量。这份代码使用了后一种实现方法。在__init__里把其他参数原封不动地传给self.conv,并在forward中直接调用self.conv(x)

1
2
3
4
5
6
7
8
9
10
11
12
class MaskConv2d(nn.Module):

def __init__(self, conv_type, *args, **kwags):
super().__init__()
...
self.conv = nn.Conv2d(*args, **kwags)
...

def forward(self, x):
...
conv_res = self.conv(x)
return conv_res

准备好卷积对象后,我们来维护掩码张量。由于输入输出都是单通道图像,按照正文中关于PixelCNN的描述,我们只需要在卷积核的h, w两个维度设置掩码。我们可以用下面的代码生成一个形状为(H, W)的掩码并根据卷积类型对掩码赋值:

1
2
3
4
5
6
7
8
9
10
def __init__(self, conv_type, *args, **kwags):
super().__init__()
assert conv_type in ('A', 'B')
...
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[0:H // 2] = 1
mask[H // 2, 0:W // 2] = 1
if conv_type == 'B':
mask[H // 2, W // 2] = 1

然后,为了保证掩码能正确广播到4维的卷积核组上,我们做一个reshape操作。

1
mask = mask.reshape((1, 1, H, W))

在初始化函数的最后,我们把用PyTorch API把mask注册成名为'mask'的成员变量。register_buffer可以把一个变量加入成员变量的同时,记录到PyTorch的Module中。这样做的好处时,每当执行model.to(device)把模型中所有参数转到某个设备上时,被注册的变量会跟着转。否则的话我们要手动model.mask = model.mask.to(device)转设备。register_buffer的第三个参数表示被注册的变量是否要加入state_dict中以保存下来。由于这里mask每次都会自动生成,我们不需要把它存下来,可以令第三个参数为False

1
self.register_buffer('mask', mask, False)

在前向传播时,只需要先让卷积核乘掩码,再做普通的卷积。

1
2
3
4
def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res

有了最核心的掩码卷积,我们来根据论文中的模型结构图把模型搭起来。

我们先照着论文实现残差块ResidualBlock。原论文并没有使用归一化,但我发现使用归一化后效果会好一点,于是往模块里加了BatchNorm。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class ResidualBlock(nn.Module):

def __init__(self, h, bn=True):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(2 * h, h, 1)
self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()
self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)
self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()
self.conv3 = nn.Conv2d(h, 2 * h, 1)
self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()

def forward(self, x):
y = self.relu(x)
y = self.conv1(y)
y = self.bn1(y)
y = self.relu(y)
y = self.conv2(y)
y = self.bn2(y)
y = self.relu(y)
y = self.conv3(y)
y = self.bn3(y)
y = y + x
return y

有了所有这些基础模块后,我们就可以拼出最终的PixelCNN了。注意,我们可以自己决定颜色有几个亮度级别。要修改亮度级别的数量,只需要修改softmax输出的通道数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class PixelCNN(nn.Module):

def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):
super().__init__()
self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)
self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()
self.residual_blocks = nn.ModuleList()
for _ in range(n_blocks):
self.residual_blocks.append(ResidualBlock(h, bn))
self.relu = nn.ReLU()
self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)
self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
self.out = nn.Conv2d(linear_dim, color_level, 1)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
for block in self.residual_blocks:
x = block(x)
x = self.relu(x)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
x = self.out(x)
return x

PixelCNN实现完毕,我们来按照同样的流程实现Gated PixelCNN。首先,我们要实现其中的垂直掩码卷积和水平掩码卷积,二者的实现和PixelCNN里的掩码卷积差不多,只是mask的内容不太一样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class VerticalMaskConv2d(nn.Module):

def __init__(self, *args, **kwags):
super().__init__()
self.conv = nn.Conv2d(*args, **kwags)
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[0:H // 2 + 1] = 1
mask = mask.reshape((1, 1, H, W))
self.register_buffer('mask', mask, False)

def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res


class HorizontalMaskConv2d(nn.Module):

def __init__(self, conv_type, *args, **kwags):
super().__init__()
assert conv_type in ('A', 'B')
self.conv = nn.Conv2d(*args, **kwags)
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[H // 2, 0:W // 2] = 1
if conv_type == 'B':
mask[H // 2, W // 2] = 1
mask = mask.reshape((1, 1, H, W))
self.register_buffer('mask', mask, False)

def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res

水平卷积其实只要用一个1x3的卷积就可以实现了。但出于偷懒(也为了方便理解),我还是在3x3卷积的基础上添加的mask

之后我们来用两种卷积搭建论文中的Gated Block。

Gated Block搭起来稍有难度。如上面的结构图所示,我们主要要维护两个v, h两个变量,分别表示垂直卷积部分的结果和水平卷积部分的结果。v会经过一个垂直掩码卷积和一个门激活函数。h会经过一个类似于残差块的结构,只不过第一个卷积是水平掩码卷积、激活函数是门激活函数、进入激活函数之前会和垂直卷积的信息融合。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class GatedBlock(nn.Module):

def __init__(self, conv_type, in_channels, p, bn=True):
super().__init__()
self.conv_type = conv_type
self.p = p
self.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, 1)
self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,
1)
self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
self.h_output_conv = nn.Conv2d(p, p, 1)
self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()

def forward(self, v_input, h_input):
v = self.v_conv(v_input)
v = self.bn1(v)
v_to_h = v[:, :, 0:-1]
v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
v_to_h = self.v_to_h_conv(v_to_h)
v_to_h = self.bn2(v_to_h)

v1, v2 = v[:, :self.p], v[:, self.p:]
v1 = torch.tanh(v1)
v2 = torch.sigmoid(v2)
v = v1 * v2

h = self.h_conv(h_input)
h = self.bn3(h)
h = h + v_to_h
h1, h2 = h[:, :self.p], h[:, self.p:]
h1 = torch.tanh(h1)
h2 = torch.sigmoid(h2)
h = h1 * h2
h = self.h_output_conv(h)
h = self.bn4(h)
if self.conv_type == 'B':
h = h + h_input
return v, h

代码中的其他地方都比较常规,唯一要注意的是vh的合成部分。这一部分的实现初看下来比较难懂。为了把v的信息贴到h上,我们并不是像前面的示意图所写的令v上移一个单位,而是用下面的代码令v下移了一个单位(下移即去掉最下面一行,往最上面一行填0)。

1
2
v_to_h = v[:, :, 0:-1]
v_to_h = F.pad(v_to_h, (0, 0, 1, 0))

为什么实际上是要对特征图v下移一个单位呢?实际上,在拼接vh时,我们是想做下面这个计算:

1
2
3
for i in range(H):
for j in range(W):
h[:, :, i, j] += v[:, :, i - 1, j]

但是,写成循环就太慢了,我们最好是能做向量化计算。注意到,vi相加的位置只差了一个单位。为了把相加的位置对齐,我们要把v往下移一个单位,把原来在i-1处的信息移到i上。这样,移动过后的v_to_h就能和h直接用向量加法并行地加到一起了。

除了vh的合成有点麻烦外,GatedBlock还有一个细节值得注意。h的计算通路中有一个残差连接,但是,在网络的第一层,每个数据是不能看到自己的。所以,当GatedBlock发现卷积类型为A类时,不应该对h做残差连接。

最后,我们来用GatedBlock搭出Gated PixelCNN。Gated PixelCNN和原版PixelCNN的结构非常相似,只是把ResidualBlock替换成了GatedBlock而已。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class GatedPixelCNN(nn.Module):

def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
super().__init__()
self.block1 = GatedBlock('A', 1, p, bn)
self.blocks = nn.ModuleList()
for _ in range(n_blocks):
self.blocks.append(GatedBlock('B', p, p, bn))
self.relu = nn.ReLU()
self.linear1 = nn.Conv2d(p, linear_dim, 1)
self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
self.out = nn.Conv2d(linear_dim, color_level, 1)

def forward(self, x):
v, h = self.block1(x, x)
for block in self.blocks:
v, h = block(v, h)
x = self.relu(h)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
x = self.out(x)
return x

准备好了模型代码,我们可以编写训练和采样的脚本了。我们先用超参数初始化好两个模型。根据论文的描述,PixelCNN有15个残差块,中间特征的通道数为128,输出前线性层的通道数为32。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from dldemos.pixelcnn.dataset import get_dataloader, get_img_shape
from dldemos.pixelcnn.model import PixelCNN, GatedPixelCNN

import torch
import torch.nn as nn
import torch.nn.functional as F

import time
import einops
import cv2

import numpy as np
import os

batch_size = 128
color_level = 8 # or 256

models = [
PixelCNN(15, 128, 32, True, color_level),
GatedPixelCNN(15, 128, 32, True, color_level)
]

if __name__ == '__main__':
os.makedirs('work_dirs', exist_ok=True)
model_id = 1
model = models[model_id]
device = 'cuda'
model_path = f'dldemos/pixelcnn/model_{model_id}_{color_level}.pth'
train(model, device, model_path)
sample(model, device, model_path,
f'work_dirs/pixelcnn_{model_id}_{color_level}.jpg')

之后是训练部分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def train(model, device, model_path):
dataloader = get_dataloader(batch_size)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
loss_fn = nn.CrossEntropyLoss()
n_epochs = 40
tic = time.time()
for e in range(n_epochs):
total_loss = 0
for x, _ in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)
y = torch.ceil(x * (color_level - 1)).long()
y = y.squeeze(1)
predict_y = model(x)
loss = loss_fn(predict_y, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * current_batch_size
total_loss /= len(dataloader.dataset)
toc = time.time()
torch.save(model.state_dict(), model_path)
print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
print('Done')

这部分代码十分常规,和普通的多分类任务十分类似。代码中值得一看的是下面几行:

1
2
3
4
y = torch.ceil(x * (color_level - 1)).long()
y = y.squeeze(1)
predict_y = model(x)
loss = loss_fn(predict_y, y)

这几行代码根据输入x得到了标签y,再做前向传播,最后用预测的predict_yy求交叉熵损失函数。这里第一个要注意的地方是y = y.squeeze(1)这一行。在PyTorch中用交叉熵函数时,标签的形状应该为[N, A, B, ...],预测值的形状应为[N, num_class, A, B, ...]。其中,A,B, ...表示数据的形状。在我们的任务中,数据是二维的,因此标签的形状应为[N, H, W],预测值的形状应为[N, num_class, H, W]。而我们在DataLoader中获得的数据的形状是[N, 1, H, W]。我们要对数据y的形状做一个变换,使之满足PyTorch的要求。这里由于输入是单通道,我们可以随便用squeeze()y长度为1的通道去掉。如果图像是多通道的话,我们则不应该修改y,而是要对预测张量y_predict做一个reshape,改成[N, num_class, C, H, W]

第二个要注意的是y = torch.ceil(x * (color_level - 1)).long()这一行。为什么需要写一个这么复杂的浮点数转整数呢?这个地方的实现需要多解释几句。在我们的代码中,PixelCNN的输入可能来自两个地方:

  1. 训练时,PixelCNN的输入来自数据集。数据集里的颜色值是0~1的浮点数。
  2. 采样时,PixelCNN的输入来自PixelCNN的输出。PixelCNN的输出是整型(别忘了,PixelCNN只能产生离散的输出)。

两种输入,一个是0~1的浮点数,一个是0~color_level-1的整数。为了统一两个输入的形式,最简单的做法是对整型颜色输入做个除法,映射到0~1里,把它统一到浮点数上。

此外,还有一个地方需要类型转换。在训练时,我们需要得到每个像素的标签,即得到每个像素颜色的真值。由于PixelCNN的输出是离散的,这个标签也得是一个离散的颜色。而标签来自训练数据,训练数据又是0~1的浮点数。因此,在计算标签时,需要做一次浮点到整型的转换。这样,整个项目里就有两个重要的类型转换:一个是在获取标签时把浮点转整型,一个是在采样时把整型转浮点。这两个类型转换应该恰好「互逆」,不然就会出现转过去转不回来的问题。

在项目中,我使用了下图所示的浮点数映射到整数的方法。0.0映射到0,(0, 1/255]映射到1,……(254/255, 1]映射到255。即浮点转整型时使用ceil(x*255),整型转浮点的时候使用x/255。这种简单的转换方法保证一个区间里的离散颜色值只会映射到一个整数上,同时把整数映射回浮点数时该浮点数也会落在区间里。如果你随手把浮点转整型写成了int(x*255),则会出现浮点转整数和整数转浮点对应不上的问题,到时候采样的结果会很不好。

由于一个整型只能映射到一个浮点数,而多个浮点数会映射到一个整数,严格来说,大部分浮点数转成整数再转回来是变不回原来的浮点数的。这两个转换过程从数学上来说不是严格的互逆。但是,如果我们马虎一点,把位于同一个区间的浮点数看成等价的,那么浮点数和整数之间的映射就是一个双射,来回转换不会有任何信息损失。

刚才代码中y = torch.ceil(x * (color_level - 1)).long()这一行实际上就是在描述怎样把训练集的浮点颜色值转换成0~color_level-1之间的整型标签的。

再来看看采样部分的代码。和正文里的描述一样,在采样时,我们把x初始化成一个0张量。之后,循环遍历每一个像素,输入x,把预测出的下一个像素填入x.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def sample(model, device, model_path, output_path, n_sample=81):

model.eval()
model.load_state_dict(torch.load(model_path))
model = model.to(device)
C, H, W = get_img_shape() # (1, 28, 28)
x = torch.zeros((n_sample, C, H, W)).to(device)
with torch.no_grad():
for i in range(H):
for j in range(W):
output = model(x)
prob_dist = F.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist,
1).float() / (color_level - 1)
x[:, :, i, j] = pixel

imgs = x * 255
imgs = imgs.clamp(0, 255)
imgs = einops.rearrange(imgs,
'(b1 b2) c h w -> (b1 h) (b2 w) c',
b1=int(n_sample**0.5))

imgs = imgs.detach().cpu().numpy().astype(np.uint8)

cv2.imwrite(output_path, imgs)

整个采样代码的核心部分是下面这几行。我们先获取模型的输出,再用softmax转换成概率分布,再用torch.multinomial(prob_dist, 1)从概率分布里采样出一个0~(color_level-1)的离散颜色值,再除以(color_level - 1)把离散颜色转换成浮点颜色(因为网络是输入是浮点颜色),最后把新像素填入生成图像。

1
2
3
4
5
output = model(x)
prob_dist = F.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist,
1).float() / (color_level - 1)
x[:, :, i, j] = pixel

上面的代码中,如前所述,/ (color_level - 1)与前面的torch.ceil(x * (color_level - 1)).long()必须是对应起来的。两个操作必须「互逆」,不然就会出问题。

当然,最后得到的图像x是一个用0~1浮点数表示的图像,可以直接把它乘255变成一个用8位字节表示的图像,这一步浮点到整型的转换是为了让图像输出,和其他图像任务的后处理是一样的,和PixelCNN对于离散颜色和连续颜色的建模不是同一个意思,不是非得取一次ceil()

PixelCNN训练起来很慢。在代码中,我默认训练40个epoch。原版PixelCNN要花一小时左右训完,Gated PixelCNN就更慢了。

以下是我得到的一些采样结果。首先是只有8个颜色级别的PixelCNN和Gated PixelCNN。

可以看出,PixelCNN经常会生成一些没有意义的「数字」,而Gated PixelCNN生成的大部分数字都是正常的。但由于颜色级别只有8,模型偶尔会生成较粗的色块。这个在Gated PixelCNN的输出里比较明显。

之后看一下正常的256个颜色级别的PixelCNN和Gated PixelCNN采样结果。

由于颜色级别增大,任务难度变大,这两个模型的生成效果就不是那么好了。当然,Gated PixelCNN还是略好一些。训练效果差,与MNIST的特性(大部分像素都是0和255)以及PixelCNN对于离散颜色的建模有关。PixelCNN的这一缺陷已经在PixelCNN++论文里分析过了。

PixelCNN++ 源码阅读

PixelCNN++在实现上细节颇多,复现起来难度较大。而且它的官方实现是拿TensorFlow写的,对于只会PyTorch的选手来说不够友好。还好,PixelCNN++的官方实现非常简练,核心代码只有两个文件,没有过度封装,也没有过度使用API,哪怕不懂TensorFlow也不会有障碍(但由于代码中有很多科学计算,阅读起来没有障碍,却难度不小)。让我们来通过阅读官方源码来学习PixelCNN++的实现。

官方代码的地址在 https://github.com/openai/pixel-cnn 。源码有两个核心文件:nn.py实现了网络模块及一些重要的训练和采样函数,model.py定义了网络的结构。让我们自顶向下地学习,先看model.py,看到函数调用后再跑到nn.py里查看实现细节。

model.py里就只有一个函数model_spec,它定义了神经网络的结构。
它的参数为:

1
2
3
4
5
6
7
8
9
10
def model_spec(x, 
h=None,
init=False,
ema=None,
dropout_p=0.5,
nr_resnet=5,
nr_filters=160,
nr_logistic_mix=10,
resnet_nonlinearity='concat_elu',
energy_distance=False):

各参数的意义为:

  • x: 形状为[N, H, W, D1]的输入张量。其中,D1表示输入通道数。对于RGB图像,D1=3
  • h: 形状为[N, K]的约束条件,即对于每个batch来说,约束条件是一个长度K的向量。这里的约束条件和Gatd PixelCNN中提出的一样,可以是文字,也可以是类别,只要约束条件最终被转换成一个向量就行。
  • init: 是否执行初始化。这和TensorFlow的实现有关,可以不管。
  • ema: 对参数使用指数移动平均,一种训练优化技巧,和论文无关,可以不管。
  • dropout_p: dropout的概率。
  • nr_resnet: U-Net每一块里有几个ResNet层(U-Net一共有6块,编码器3块解码器3块)。
  • nr_filters: 每个卷积层的卷积核个数,即所有中间特征图的通道数。
  • nr_logistic_mix: 论文里的$K$,表示用几个logistic分布混合起来描述一个颜色分布。
  • resnet_nonlinearity: 激活函数的类别。
  • energy_distance:是否使用论文里没提过的一种算损失函数的办法,可以不管。

之后来看函数体。20行with arg_scope ([nn.conv2d, ...], counters=counters, ...)大概是说进入了TensorFlow里的arg_scope这个上下文。只要在上下文里,后面counters等参数就会被自动传入nn.conv2d等函数,而不需要在函数里显式传参。这样写会让后面的函数调用更简短一点。

22行至30行在选择激活函数,可以直接跳过。

1
2
3
4
5
6
7
8
9
# parse resnet nonlinearity argument
if resnet_nonlinearity == 'concat_elu':
resnet_nonlinearity = nn.concat_elu
elif resnet_nonlinearity == 'elu':
resnet_nonlinearity = tf.nn.elu
elif resnet_nonlinearity == 'relu':
resnet_nonlinearity = tf.nn.relu
else:
raise('resnet nonlinearity ' + resnet_nonlinearity + ' is not supported')

从35行开始,函数正式开始定义网络结构。一开始,代码里有一个匪夷所思的操作:先是取出输入张量的形状xs,再根据这个形状给x填充了一个全是1的通道。

1
2
xs = nn.int_shape(x)
x_pad = tf.concat([x,tf.ones(xs[:-1]+[1])],3) # add channel of ones to distinguish image from padding later on

虽然作者加了注释,说这个x_pad后面会用到。但我翻遍了代码,楞是没找到这个多出来的通道发挥了什么作用。GitHub issue里也有人提问,问这个x_pad在做什么。有其他用户给了回复,说他尝试了去掉填充,结果不变。可见这一行代码确实是毫无贡献,还增加了不必要的计算量。大概是作者没删干净过时的实现代码。

之后的几行是在初始化上卷积和左上卷积的中间结果(上卷积和Gated PixelCNN里的垂直卷积类似,左上卷积和Gated PixelCNN里的水平卷积类似)。u_list会保存所有上卷积在编码器里的结果,ul_list会保存所有左上卷积在编码器里的结果。这些结果会供解码器使用。

1
2
3
4
5
6
7
8
9
10
11
12
 u_list = [nn.down_shift(
nn.down_shifted_conv2d(x_pad,
num_filters=nr_filters,
filter_size=[2, 3])
)] # stream for pixels above
ul_list = [nn.down_shift(
nn.down_shifted_conv2d(x_pad,
num_filters=nr_filters,
filter_size=[1,3])
) + nn.right_shift(
nn.down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1])
)] # stream for up and to the left

作者没有使用带掩码的卷积,而是通过普通卷积加偏移等效实现了掩码卷积。这一实现非常巧妙,效率更高。我们来看看这几个卷积的实现方法。

首先看上卷积down_shifted_conv2d,它表示实现一个卷积中心在卷积核正下方的卷积。作者使用了[2,3]的卷积核,并手动给卷积填充(注意,卷积的类型是'valid'不是'same')。这种卷积等价于我们做普通的3x3卷积再给上面6个像素打上掩码。

1
2
3
def down_shifted_conv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs):
x = tf.pad(x, [[0,0],[filter_size[0]-1,0], [int((filter_size[1]-1)/2),int((filter_size[1]-1)/2)],[0,0]])
return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs)

作者在down_shifted_conv2d之后跟了一个down_shift。这个操作和我们实现Gated PixelCNN时移动v_to_h张量的做法一样,去掉张量最下面一行,在最上面一行填0,也就是让张量往下移了一格。

1
2
3
def down_shift(x):
xs = int_shape(x)
return tf.concat([tf.zeros([xs[0],1,xs[2],xs[3]]), x[:,:xs[1]-1,:,:]],1)

类似地,在做第一次左上卷积时,作者把一个下移过的1x3卷积结果和一个右移过的2x1卷积结果拼到了一起。其中,down_right_shifted_conv2d就是实现一个卷积中心在卷积核右下角的卷积。

1
2
3
4
5
6
7
ul_list = [nn.down_shift(
nn.down_shifted_conv2d(x_pad,
num_filters=nr_filters,
filter_size=[1,3])
) + nn.right_shift(
nn.down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1])
)]

初始化完毕后,数据就正式进入了U-Net。让我们先略过函数的细节,看一看模型的整体架构。在下采样部分,三级U-Net在每一级都是先经过若干个gated_resnet模块,再下采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
for rep in range(nr_resnet):
u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d))

u_list.append(nn.down_shifted_conv2d(u_list[-1], num_filters=nr_filters, stride=[2, 2]))
ul_list.append(nn.down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, stride=[2, 2]))

for rep in range(nr_resnet):
u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d))

u_list.append(nn.down_shifted_conv2d(u_list[-1], num_filters=nr_filters, stride=[2, 2]))
ul_list.append(nn.down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, stride=[2, 2]))

for rep in range(nr_resnet):
u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d))

之后是上采样。类似地,数据先经过若干个gated_resnet模块,再上采样。与前半部分不同的是,前半部分的输出会从u_listul_list中逐个取出(实际上这两个list起到了一个栈的作用),接入到gated_resnet的输入里。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
u = u_list.pop()
ul = ul_list.pop()
for rep in range(nr_resnet):
u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d)
tf.add_to_collection('checkpoints', u)
tf.add_to_collection('checkpoints', ul)

u = nn.down_shifted_deconv2d(u, num_filters=nr_filters, stride=[2, 2])
ul = nn.down_right_shifted_deconv2d(ul, num_filters=nr_filters, stride=[2, 2])

for rep in range(nr_resnet+1):
u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d)
tf.add_to_collection('checkpoints', u)
tf.add_to_collection('checkpoints', ul)

u = nn.down_shifted_deconv2d(u, num_filters=nr_filters, stride=[2, 2])
ul = nn.down_right_shifted_deconv2d(ul, num_filters=nr_filters, stride=[2, 2])

for rep in range(nr_resnet+1):
u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d)
tf.add_to_collection('checkpoints', u)
tf.add_to_collection('checkpoints', ul)

模型U-Net的部分到此为止。整个网络的结构并不复杂,我们只要看懂了nn.gated_resnet的实现,就算理解了整个模型的实现。让我们来详细看一下这个模块是怎么实现的。以下是整个模块的实现代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs):
xs = int_shape(x)
num_filters = xs[-1]

c1 = conv(nonlinearity(x), num_filters)
if a is not None: # add short-cut connection if auxiliary input 'a' is given
c1 += nin(nonlinearity(a), num_filters)
c1 = nonlinearity(c1)
if dropout_p > 0:
c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)
c2 = conv(c1, num_filters * 2, init_scale=0.1)

# add projection of h vector if included: conditional generation
if h is not None:
with tf.variable_scope(get_name('conditional_weights', counters)):
hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,
initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
if init:
hw = hw.initialized_value()
c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])

a, b = tf.split(c2, 2, 3)
c3 = a * tf.nn.sigmoid(b)
return x + c3

照例,我们来先看一下函数的每个参数的意义。

1
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs)
  • x: 模块的输入。
  • a: 模块的附加输入。附加输入有两个来源:上方u_list的信息传递给左上方ul_list的信息、编码器把信息传递给解码器。
  • h: 形状为[N, K]的约束条件。从模型的参数里传递而来。
  • nonlinearity: 激活函数。从模型的参数里传递而来。
  • conv:卷积操作的函数。可能是上卷积或者左上卷积。
  • init: 是否执行初始化。这和TensorFlow的实现有关,可以不管。
  • counters: 作者写的一个用于方便地给模块的命名的字典,可以不管。
  • ema: 对参数使用指数移动平均。从模型的参数里传递而来。
  • dropout_p: dropout的概率。从模型的参数里传递而来。

模块主要是做了下面这些卷积操作。一开始,先对输入x做卷积,得到c1。如果有额外输入a,则对a做一个1x1卷积(作者自己实现了1x1卷积,把函数命名为nin),加到c1上。做完第一个卷积后,过一个dropout层。最后再卷积一次,得到2*num_filters通道数的张量。

1
2
3
4
5
6
7
c1 = conv(nonlinearity(x), num_filters)
if a is not None: # add short-cut connection if auxiliary input 'a' is given
c1 += nin(nonlinearity(a), num_filters)
c1 = nonlinearity(c1)
if dropout_p > 0:
c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)
c2 = conv(c1, num_filters * 2, init_scale=0.1)

之后,作者也使用了一种门结构作为整个模块的激活函数。但是和Gated PixelCNN相比,PixelCNN++的门结构简单一点。详见下面的代码。

1
2
a, b = tf.split(c2, 2, 3)
c3 = a * tf.nn.sigmoid(b)

最后输出时,c3和输入x之间有一个残差连接。

1
return x + c3

看完gated_resnet的实现,我们可以跳回去继续看模型结构了。经过了U-Net的主体结构后,只需要经过一个输出层就可以得到最终的输出了。输出层里,作者用1x1卷积修改了输出通道数,令最后的通道数为10*nr_logistic_mix

1
2
3
4
5
6
7
8
9
if energy_distance:
# 跳过
else:
x_out = nn.nin(tf.nn.elu(ul),10*nr_logistic_mix)

assert len(u_list) == 0
assert len(ul_list) == 0

return x_out

大家还记得这个10是从哪里来的吗?在正文中,我们曾经学过,对于某个像素的第$i$个logistic分布,网络会输出10个参数:$\pi, \mu_r, \mu_g, \mu_b, s_r, s_g, s_b, \alpha, \beta, \gamma$。这个10就是10个参数的意思。

光知道一共有10个参数还不够。接下来就是PixelCNN++比较难懂的部分——怎么用这些参数构成一共logistic分布,并从连续分布中得到离散的概率分布。这些逻辑被作者写在了损失函数nn.discretized_mix_logistic_loss里面。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def discretized_mix_logistic_loss(x,l,sum_all=True):
""" log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """
xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100)
nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics
logit_probs = l[:,:,:,:nr_mix]
l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3])
means = l[:,:,:,:,:nr_mix]
log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.)
coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])
x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels
m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix])
m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix])
means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3],3)
centered_x = x - means
inv_stdv = tf.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1./255.)
cdf_plus = tf.nn.sigmoid(plus_in)
min_in = inv_stdv * (centered_x - 1./255.)
cdf_min = tf.nn.sigmoid(min_in)
log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling)
log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling)
cdf_delta = cdf_plus - cdf_min # probability for all other cases
mid_in = inv_stdv * centered_x
log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code)

log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.where(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5))))

log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs)
if sum_all:
return -tf.reduce_sum(log_sum_exp(log_probs))
else:
return -tf.reduce_sum(log_sum_exp(log_probs),[1,2])

这个函数很长,很难读。它实际上可以被拆成四个部分:取参数、求均值、求离散概率、求和。让我们一部分一部分看过来。

首先是取参数部分,这部分代码如下所示。模型一共输出了10*nr_mix个参数,即输出了nr_mix组参数,每组有10个参数。如前所述,第一个参数是选择该分布的未经过softmax的概率logit_probs,之后的6个参数是三个通道的均值及三个通道的标准差取log,最后3个参数是描述通道间依赖关系的$\alpha, \beta, \gamma$。不用去认真阅读这段代码,只需要知道这些代码可以把数据取出来即可。

1
2
3
4
5
6
7
8
xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100)
nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics
logit_probs = l[:,:,:,:nr_mix]
l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3])
means = l[:,:,:,:,:nr_mix]
log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.)
coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])

之后是求均值部分。在第一行,作者用了一种曲折的方式实现了repeat操作,把x在最后一维重复了nr_mix次,方便后续处理。在第二第三行,作者根据论文里的公式,调整了G通道和B通道的均值。在最后第四行,作者把所有均值张量拼到了一起。

1
2
3
4
x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels
m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix])
m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix])
means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3],3)

再来是求离散概率部分。作者根据论文里的公式,算出了当前离散分布的积分上限和积分下限(通过从累计分布密度函数里取值),再做差,得到了离散分布的概率。由于最终的概率值要求log,作者没有按照公式的顺序先算累计分布概率函数的值,再取log,而是把所有计算放到一起并化简。这样代码虽然难读了一点,但减少了不必要的计算,也减少了精度损失。

1
2
3
4
5
6
7
8
9
 centered_x = x - means
inv_stdv = tf.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1./255.)
cdf_plus = tf.nn.sigmoid(plus_in)
min_in = inv_stdv * (centered_x - 1./255.)
cdf_min = tf.nn.sigmoid(min_in)
log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling)
log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling)
cdf_delta = cdf_plus - cdf_min # probability for all other cases

作者还算了积分区间中心的概率,以处理某些边界情况。实际上这个值没有在代码中使用。

1
2
3
mid_in = inv_stdv * centered_x
log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in)
# log probability in the center of the bin, to be used in extreme cases (not actually used in our code)

光做差还不够。为了处理颜色值在0和255的边界情况,作者还给代码加入了一些边界上的特判,才得到了最终的概率log_probs
1
2
3
4
5
log_probs = tf.where(x < -0.999, log_cdf_plus, 
tf.where(x > 0.999, log_one_minus_cdf_min,
tf.where(cdf_delta > 1e-5,
tf.log(tf.maximum(cdf_delta, 1e-12)),
log_pdf_mid - np.log(127.5))))

最后是loss求和部分。除了要把离散概率的对数求和外,还要加上选择这个分布的概率的对数。log_prob_from_logits就是做一个softmax再求一个log。算上了选择分布的概率后,再对loss求一次和,就得到了最终的loss。

1
2
3
4
5
log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs)
if sum_all:
return -tf.reduce_sum(log_sum_exp(log_probs))
else:
return -tf.reduce_sum(log_sum_exp(log_probs),[1,2])

至此,我们就看完了训练部分的关键代码。我们再来看一看采样部分最关键的代码,怎么从logisitc分布里采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def sample_from_discretized_mix_logistic(l,nr_mix):
ls = int_shape(l)
xs = ls[:-1] + [3]
# unpack parameters
logit_probs = l[:, :, :, :nr_mix]
l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3])
# sample mixture indicator from softmax
sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5))), 3), depth=nr_mix, dtype=tf.float32)
sel = tf.reshape(sel, xs[:-1] + [1,nr_mix])
# select logistic parameters
means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4)
log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.)
coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4)
# sample from logistic & clip to interval
# we don't actually round to the nearest 8bit value when sampling
u = tf.random_uniform(means.get_shape(), minval=1e-5, maxval=1. - 1e-5)
x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u))
x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.)
x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.)
x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.)
return tf.concat([tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])],3)

一开始,还是和刚刚的求loss一样,作者把参数从网络输出l里拆出来。logit_probs是选择某分布的未经softmax的概率,其余的参数是均值、标准差、通道间依赖参数。

1
2
3
4
5
6
def sample_from_discretized_mix_logistic(l,nr_mix):
ls = int_shape(l)
xs = ls[:-1] + [3]
# unpack parameters
logit_probs = l[:, :, :, :nr_mix]
l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3])

之后,作者对logit_probs做了一个softmax,得到选择各分布的概率。之后,作者根据这个概率分布采样,从nr_mix个logistic分布里选了一个做为这次生成使用的分布。作者没有使用下标来选择数据,而是把选中的序号编码成one-hot向量sel,通过乘one-hot向量来实现从某数据组里取数。

1
2
sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5))), 3), depth=nr_mix, dtype=tf.float32)
sel = tf.reshape(sel, xs[:-1] + [1,nr_mix])

接着,作者根据sel,取出nr_mix个logistic分布中某一个分布的均值、标准差、依赖系数。

1
2
3
4
# select logistic parameters
means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4)
log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.)
coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4)

再然后,作者用下面两行代码完成了从logistic分布的采样。从一个连续概率分布里采样是一个基础的数学问题。其做法是先求概率分布的累计分布函数。由于累计分布函数可以把自变量一一映射到0~1之间的概率,我们就得到了一个0~1之间的数到自变量的映射,即累积分布函数的反函数。通过对0~1均匀采样,再套入累积分布函数的反函数,就完成了采样。下面第二行计算其实就是在算logisitc分布的累积分布函数的反函数的一个值。

1
2
u = tf.random_uniform(means.get_shape(), minval=1e-5, maxval=1. - 1e-5)
x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u))

只从分布里采样还不够,我们还得算上依赖系数。把依赖系数的贡献算完后,整个采样就结束了,我们得到了RGB三个颜色值。

1
2
3
4
x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.)
x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.)
x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.)
return tf.concat([tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])],3)

至此,PixelCNN++中最具有学习价值的代码就看完了。让我再次总结一下PixelCNN++中的重要代码,并介绍一下学习它们需要什么前置知识。

PixelCNN++中第一个比较重要的地方是掩码卷积的实现。它没有真的使用到掩码,而是使用了卷积中心在卷积核下方和右下角的卷积来等价实现。要读懂这些代码,你需要先看懂PixelCNN和Gated PixelCNN里面对于掩码卷积的定义,知道PixelCNN++为什么要做两种卷积。之后,你还需要对卷积操作有一点基础的认识,知道卷积操作的填充方式其实是在改变卷积中心在卷积核中的位置。你不需要懂太多TensorFlow的知识,毕竟卷积的API就那么几个参数,每个框架都差不多。

PixelCNN++的另一个比较重要的地方是logistic分布的离散概率计算与采样。为了学懂这些,你需要一点比较基础的统计学知识,知道概率密度函数与累积分布函数的关系,知道怎么用计算机从一个连续分布里采样。之后,你要读懂PixelCNN++是怎么用logistic分布对离散概率建模的,知道logistic分布的累计分布函数就是sigmoid函数。懂了这些,你看代码就不会有太多问题,代码基本上就是对论文内容的翻译。反倒是如果读论文没读懂,可以去看代码里的实现细节。

最近,我需要在Python里使用PatchMatch算法(一种算两张图片逐像素匹配关系的算法)。我去网上搜了一份实现,跑了下测试程序,发现它跑边长300像素的图片都要花三分钟。这个速度实在太慢了。我想起以前瞟到过一篇介绍Python加速的文章,里面提到过Numba这个库。于是,我现学现用,最终成功用Numba让原来要跑180秒的程序在0.6秒左右跑完。可见,Numba学起来是很快的。在这篇文章中,我将以这个Python版PatchMatch项目为例,介绍如何快速从零上手Numba,以大幅加速Python科学计算程序。这篇文章不会涉及PatchMatch算法的原理,只要你写过Python,就能读懂本文。

缘起:一份缓慢的 PatchMatch 实现

PatchMatch是Adobe提出的一种快速计算两张图片逐像素匹配关系的算法。也就是说,输入两张类似的图片A和B(比如视频里的连续两帧),算法能输出图片A中的每个像素对应B中的哪个像素(可能会出现多对一的情况)。为了快速验证算法的效果,我们可以输入图片A和B,用算法获取A到B的匹配关系,再根据匹配关系从B中取像素重建A。如果重建出来的图片和原来的图片A看上去差不多,那算法的效果就很不错。下图是一份PatchMatch测试程序的输出。

我在GitHub上找到了一份简明实用的Python版PatchMatch实现,得到了上面的输出结果。结果是挺不错,但哪怕是跑326x244这么小的图片,都要花约180秒才能跑完。

我懒得从头学一遍PatchMatch,决定直接上手优化代码。代码不长,其函数调用关系能很快理清。

代码入口函数是NNS()。它先是调用了initialization(),再循环itr次,每次遍历所有像素,对每个像素调用propagation()random_search()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def NNS(img, ref, p_size, itr):
A_h = np.size(img, 0)
A_w = np.size(img, 1)
f, dist, img_padding = initialization(img, ref, p_size)
for itr in range(1, itr + 1):
if itr % 2 == 0:
for i in range(A_h - 1, -1, -1):
for j in range(A_w - 1, -1, -1):
a = np.array([i, j])
propagation(f, a, dist, img_padding, ref, p_size, False)
random_search(f, a, dist, img_padding, ref, p_size)
else:
for i in range(A_h):
for j in range(A_w):
a = np.array([i, j])
propagation(f, a, dist, img_padding, ref, p_size, True)
random_search(f, a, dist, img_padding, ref, p_size)
return f

initialization()先是定义了一些变量,之后对所有像素调用cal_distance()

1
2
3
4
5
6
7
8
9
def initialization(A, B, p_size):
A_h = np.size(A, 0)
A_w = np.size(A, 1)
...
for i in range(A_h):
for j in range(A_w):
...
dist[i, j] = cal_distance(a, b, A_padding, B, p_size)
return f, dist, A_padding

propagation()主要调用了一次cal_distance()

1
2
3
4
5
6
7
8
9
10
11
12
def propagation(f, a, dist, A_padding, B, p_size, is_odd):
...
if is_odd:
if idx == 1:
...
dist[x, y] = cal_distance(a, f[x, y], A_padding, B, p_size)
if idx == 2:
...
dist[x, y] = cal_distance(a, f[x, y], A_padding, B, p_size)
else:
# 和 is_odd 时类似
...

random_search()则主要是在一个while循环里反复调用cal_distance()

1
2
3
4
5
6
def random_search(f, a, dist, A_padding, B, p_size, alpha=0.5):
...
while search_h > 1 and search_w > 1:
...
d = cal_distance(a, b, A_padding, B, p_size)
...

最后来看被调用最多的cal_distance()。这个函数用于计算图片A,B之间的某个距离。也别管这个距离是什么意思,总之是这一个有点耗时的计算。

1
2
3
4
5
6
7
8
9
def cal_distance(a, b, A_padding, B, p_size):
p = p_size // 2
patch_a = A_padding[a[0]:a[0] + p_size, a[1]:a[1] + p_size, :]
patch_b = B[b[0] - p:b[0] + p + 1, b[1] - p:b[1] + p + 1, :]
temp = patch_b - patch_a
num = np.sum(1 - np.int32(np.isnan(temp)))
dist = np.sum(np.square(np.nan_to_num(temp))) / num
return dist

至此,这份程序就差不多看完了。可以发现,代码大部分时候都在遍历像素,且遍历每个像素时多次调用cal_distance()函数。而我们知道,拿Python本身做计算是很慢的,尤其是在一个很长的循环里反复计算。这份代码性能较低,正是因为代码在遍历每个像素时做了大量计算。

我以前看过一篇文章,说Numba库能够加速Python科学计算程序,尤其是加速带有大量循环的程序。于是,我去学习了一下Numba的基础用法。

Numba 基础

Numba的官方文档提供了非常友好的入门教程。我们来大致把教程过一下。

Numba可以用pip一键安装。

1
pip install numba

Numba尤其擅长加速循环以及和NumPy相关的计算。使用@jit(nopython=True)(或@njit)装饰一个函数后,我们可以在这个函数里随便写循环,随便用NumPy计算,就像在用C语言一样。经Numba优化后,这个函数会跑得飞快。以下是官方给出的入门示例程序。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from numba import jit, njit
import numpy as np

x = np.arange(100).reshape(10, 10)


@jit(nopython=True) # 设置 "nopython" 模式以获取最优性能,等价于 @njit
def go_fast(a): # 初次调用时函数将被编译成机器码
trace = 0.0
for i in range(a.shape[0]): # Numba 喜欢循环
trace += np.tanh(a[i, i]) # Numba 喜欢 NumPy 函数
return a + trace # Numba 喜欢 NumPy 广播


print(go_fast(x))

Numba是怎么完成加速的呢?从装饰器名jit(JIT,Just-In-Time Compiler的简称)中,我们能猜出,Numba使用了即时编译技术,把函数直接翻译成了机器码,而没有像普通Python程序一样解释执行。Numba有两种编译模式,最常见的模式是令参数nopython=True,在编译中完全不用Python解释器。这种模式下,函数能以最优性能翻译成机器码。

修改上面的代码,我们可以测试该函数的速度。注意,由于采用了即时编译,函数在初次调用时会被编译。如果只要计算函数在编译后的运行时间,应该从第二次调用后开始计时。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from numba import jit
import numpy as np
import time

x = np.arange(100).reshape(10, 10)


@jit(nopython=True)
def go_fast(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace


# 不要汇报这个速度,因为编译时间也被算进去了
start = time.perf_counter()
go_fast(x)
end = time.perf_counter()
print("Elapsed (with compilation) = {}s".format((end - start)))

# 现在函数已经被编译了,用缓存好的函数重新计时
start = time.perf_counter()
go_fast(x)
end = time.perf_counter()
print("Elapsed (after compilation) = {}s".format((end - start)))

我们可以得到类似于下面的输出:

1
2
Elapsed (with compilation) = 1.0579542s
Elapsed (after compilation) = 1.7699999999898353e-05s

我们可以尝试一下不用Numba,直接用Python循环。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import numpy as np
import time

x = np.arange(100).reshape(10, 10)


def go_slowly(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace


start = time.perf_counter()
go_slowly(x)
end = time.perf_counter()
print("Elapsed (without Numba) = {}s".format((end - start)))

这个速度(4e-4)比用Numba慢了一个数量级。

1
Elapsed (without Numba) = 0.00046979999999985367s

也就是说,我们只要在普通的Python计算函数上加一个@jit(nopython=True)(或@njit),其他什么都不用做,就可以加速代码了。让我们来用它改进一下之前的PatchMatch程序。

用Numba计时编译加速PatchMatch

让我们开始做PatchMatch的性能调优。首先,根据性能优化的一般做法,我们要得知每一行函数调用的运行时间,找到性能瓶颈,从瓶颈处开始优化。我们可以用line_profiler来分析每一行代码的运行时间。用pip即可安装这个库。

1
pip install line_profiler

把主函数修改一下,在调用算法入口函数时拿LineProfiler封装一下,再用lp.add_function添加想监控的函数,即可开始性能分析。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
if __name__ == "__main__":
img = np.array(Image.open("./cup_a.jpg"))
ref = np.array(Image.open("./cup_b.jpg"))
p_size = 3
itr = 5

# start = time.time()
# f = NNS(img, ref, p_size, itr)
# end = time.time()
# print(end - start)

# reconstruction(f, img, ref)

lp = LineProfiler()
lp_wrapper = lp(NNS)
lp.add_function(propagation)
lp.add_function(random_search)
f = lp_wrapper(img, ref, p_size, itr)
lp.print_stats()

性能分析结果会显示每一行代码的运行时间及占用时间百分比。从结果中可以看出,在入口函数NNS()中,random_search()最为耗时。这是符合预期的,因为random_search()里还有一层while循环。

现在,我们应该着重优化random_search()的性能。我们继续查看一下random_search()的性能分析结果。

结果显示,绝大多数时间都消耗在了while循环里。也和我们之前分析得一样,cal_distance()是耗时最多的一行。除了random_search()外,其他几个函数也多次调用了cal_distance()。因此,我们目前代码优化的目标就定格在了cal_distance()身上。

刚刚学完了Numba,这不正好可以用上了吗?我们可以尝试直接给cal_distance()加一个@njit装饰器。

1
2
3
4
5
6
7
8
9
@njit
def cal_distance(a, b, A_padding, B, p_size):
p = p_size // 2
patch_a = A_padding[a[0]:a[0] + p_size, a[1]:a[1] + p_size, :]
patch_b = B[b[0] - p:b[0] + p + 1, b[1] - p:b[1] + p + 1, :]
temp = patch_b - patch_a
num = np.sum(1 - np.int32(np.isnan(temp)))
dist = np.sum(np.square(np.nan_to_num(temp))) / num
return dist

修改完代码后,再次运行程序。这次,程序报了一大堆错误。大致是说,在某一行碰到了Numba识别不了的函数。应该把np.int32()的强制类型转换改成.astype(np.int32)

1
2
num = np.sum(1 - np.int32(np.isnan(temp)))
^

改完之后,如果Numba版本较老,还会碰到新的报错:

1
Use of unsupported NumPy function 'numpy.nan_to_num' or unsupported use of the function.

报错显示,numpy.nan_to_num函数没有得到支持。再次翻阅Numba文档,可以发现,Numba并不支持所有NumPy函数。Numba对NumPy的支持情况可以在文档里查询(需要把文档切换到你当前Numba的版本)。

总之,cal_distance()这个函数不改不行了。得认真阅读一下这个函数的原理。原来,cal_distance(a, b, A_padding, B, p_size)函数是算图像A_padding和图像B中某一个像素块的均方误差的平均值,其中,像素块的边长为p_size,像素块在A_padding的坐标由a表示,在B中的坐标由b表示。

1
2
3
4
5
6
7
8
9
10
11
12
13
def cal_distance(a, b, A_padding, B, p_size):
p = p_size // 2
# 根据坐标a和边长p从A_padding里取像素块
patch_a = A_padding[a[0]:a[0] + p_size, a[1]:a[1] + p_size, :]
# 根据坐标b和边长p从A_padding里取像素块
patch_b = B[b[0] - p:b[0] + p + 1, b[1] - p:b[1] + p + 1, :]
# 求差
temp = patch_b - patch_a
# 根据非nan像素数量算有效像素数量
num = np.sum(1 - np.isnan(temp).astype(np.int32))
# 排除nan,求差的平方和,再除以有效像素数量
dist = np.sum(np.square(np.nan_to_num(temp))) / num
return dist

代码里还有一些奇怪的有关nan的运算:如果像素块里某处有nan,就说明此处像素无效,不应该参与均方误差的运算。为什么图像里会有nan呢?我们得阅读代码的其他部分。

nan是在初始化函数initialization()里加入的。A_padding原来是图像A在周围填了一圈nan的结果。我们大致能猜测出作者填充nan的原因:从A中取像素块时,若像素块在边缘,则有一些像素就不应该被计算了。拿条件语句判断这些无效像素比较麻烦,作者选择干脆在图像A周围填一圈nan,保证每次取像素块时不用判断无效像素。等算误差的时候再判断根据nan排除无效像素。

1
2
3
4
5
6
def initialization(A, B, p_size):
A_h = np.size(A, 0)
A_w = np.size(A, 1)
p = p_size // 2
A_padding = np.ones([A_h + p * 2, A_w + p * 2, 3]) * np.nan
A_padding[p:A_h + p, p:A_w + p, :] = A

使用nan填充,既耗时,兼容性又不好。为了尽可能加速cal_distance(),我把填充改成了edge填充,即让填充值等于边界值,并取消了无效像素的判断。也就是说,若像素块取到了图像外的像素,则认为这个像素和边界处的像素一样。这个假设是很合理的,这种修改几乎不会损耗算法的效果。

除此之外,为了进一步减少cal_distance()中的计算,我把要用到的变量都提前在外面算好再传进来。由于现在不需要考虑无效像素的数量,可以直接对误差求和,不用再算平均值,少做一次除法。还有,现在用@njit装饰了函数,可以放心大胆地在循环里做计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def initialization(A, B, p_size):
A_h = np.size(A, 0)
A_w = np.size(A, 1)
p = p_size // 2
A_padding = np.pad(A, ((p, p), (p, p), (0, 0)), mode='edge')

# Numba 循环写法
@njit
def cal_distance(x, y, x2, y2, A_padding, B, p):
sum = 0
for i in range(p + p + 1):
for j in range(p + p + 1):
for k in range(3):
a = float(A_padding[x + i, y + j, k])
bb = B[x2 - p + i, y2 - p + j, k]
sum += (a - bb)**2
return sum

当然,用NumPy实现cal_distance也是可以的。

1
2
3
4
5
# NumPy 等价写法,加上@njit更快
def cal_distance(x, y, x2, y2, A_padding, B, p):
patch_a = A_padding[x:x + p, y:y + p, :].astype(np.float32)
patch_b = B[x2 - p:x2 + p + 1, y2 - p:y2 + p + 1, :]
return np.sum((patch_a - patch_b)**2)

经测试,把nan的判断全部去掉后,使用NumPy版的cal_distance(),程序的运行时间降到了60秒。给NumPy版的cal_distance()加上@njit,运行时间进一步降低到了33秒。而如果使用带@njit装饰的循环写法,则运行时间也差不多是33秒,甚至还略快一些。这些测试结果印证了Numba的特性:

  1. Numba可以加速和NumPy张量相关的计算
  2. 在Numba中使用循环不会降低运行速度

成功用@njit优化完了代码中最深层的cal_distance(),我们会想,是不是所有函数都可以用同样方法加速?我们可以来做个实验,给最外层的入口函数NNS()加上@njit

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
@njit
def NNS(img, ref, p_size, itr):
A_h = np.size(img, 0)
A_w = np.size(img, 1)
f, dist, img_padding = initialization(img, ref, p_size)
for itr in range(1, itr + 1):
if itr % 2 == 0:
for i in range(A_h - 1, -1, -1):
for j in range(A_w - 1, -1, -1):
a = np.array([i, j])
propagation(f, a, dist, img_padding, ref, p_size, False)
random_search(f, a, dist, img_padding, ref, p_size)
else:
for i in range(A_h):
for j in range(A_w):
a = np.array([i, j])
propagation(f, a, dist, img_padding, ref, p_size, True)
random_search(f, a, dist, img_padding, ref, p_size)
return f

运行程序,会得到类似于下面的报错:

1
Untyped global name 'initialization': Cannot determine Numba type of <class 'function'>

把报错放网上一搜,原来,@njit的自定义函数只能调用加@njit的自定义函数。也就是说,在上面这份代码里,我们虽然用@njit装饰了NNS(),但我们自己定义的initialization(), propagation(),random_search()全部都没有用@njit装饰,因此NNS()的编译会出错。看来,我们得自底向上一步一步加上@njit了。

先来尝试修改一下initialization()。很可惜,直接加上@njit会报错。

1
2
3
4
5
6
7
8
9
10
11
12
13
@njit
def initialization(A, B, p_size):
A_h = np.size(A, 0)
A_w = np.size(A, 1)
B_h = np.size(B, 0)
B_w = np.size(B, 1)
p = p_size // 2
random_B_r = np.random.randint(p, B_h - p, [A_h, A_w])
random_B_c = np.random.randint(p, B_w - p, [A_h, A_w])
...

#报错
# Use of unsupported NumPy function 'numpy.size' or unsupported use of the function.

报错是说有不支持的NumPy函数numpy.size。实际上,不仅是numpy.size,Numba也不支持有三个参数的np.random.randint。为了解决此问题,和刚刚对numpy.nan_to_num的处理一样,最好是能用其他等价写法来代替不支持的函数。如果不行的话,则应该把不支持的运算和支持的运算分离开,只加速支持的那一部分。对于initialization(),我采用了第二种解决方法,把函数中耗时的循环拆开来单独用@njit装饰,其余有不支持的NumPy函数的部分就不用Numba优化了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@njit
def initialization_loop(A_padding, B, f, dist, A_h, A_w, random_B_r,
random_B_c, p):
for i in range(A_h):
for j in range(A_w):
x, y = random_B_r[i, j], random_B_c[i, j]
f[i, j, 0] = x
f[i, j, 1] = y
dist[i, j] = cal_distance(i, j, x, y, A_padding, B, p)


def initialization(A, B, A_h, A_w, B_h, B_w, p_size):
p = p_size // 2
random_B_r = np.random.randint(p, B_h - p, [A_h, A_w])
random_B_c = np.random.randint(p, B_w - p, [A_h, A_w])
A_padding = np.pad(A, ((p, p), (p, p), (0, 0)), mode='edge')
f = np.zeros([A_h, A_w, 2], dtype=np.int32)
dist = np.zeros([A_h, A_w])
initialization_loop(A_padding, B, f, dist, A_h, A_w, random_B_r,
random_B_c, p)
return f, dist, A_padding

另外的两个函数propagation()random_search()只会碰到取形状函数numpy.size的问题。这个问题很好解决,只要把numpy.size挪到函数调用外即可。

initialization_loop()propagation()random_search()都加上@njit后,程序的运行时间从33秒猛地降到了3秒左右。可以说,只用加@njit的方法的话,程序已经没有优化空间了。

用Numba提前编译加速PatchMatch

又看了看PatchMatch的源码,我发现,PatchMatch算法会先为每个像素随机生成一个匹配关系。然后,算法会迭代更新匹配关系。迭代得越久,匹配关系越准。而我之后要用PatchMatch处理一段视频,算所有帧对第1帧的匹配关系。那么,对于视频这种连续的图像序列,我能不能让第3帧初始化匹配关系时复用第2帧的匹配结果,第4帧复用第3帧的匹配关系,以此类推,以减少迭代次数呢?

说干就干。我准备先测试一下减少迭代次数后代码运行时间能缩短多少。迭代次数itr是在main函数里指定的,作者默认的数值是5。我把它改成1测试了一下。

1
2
3
4
5
6
7
8
9
10
11
if __name__ == "__main__":
img = np.array(Image.open("./cup_a.jpg"))
ref = np.array(Image.open("./cup_b.jpg"))
p = 3
itr = 1

start = time.time()
f = NNS(img, ref, p, itr)
end = time.time()
print(end - start)
reconstruction(f, img, ref)

结果,原来要花3秒的程序还是要花接近3秒,时间缩短得非常不明显。这不应该啊,理论上程序的运行时间应该大致和itr成正比啊。

测试了半天,我突然想起Numba文档里讲过,@njit是即时编译,函数的编译会在初次调用函数时完成。我每次运行程序时,大部分时间都花在了编译上,因此整个程序的运行时间几乎不由迭代次数决定。

我之后要反复运行PatchMatch程序,而不是通过运行一次程序来处理大批数据。即时编译的代价我是接受不了的。于是,我去文档里找到了Numba提前编译(AOT,ahead of time)的使用方法。

Numba AOT可以把Python函数编译进一个模块文件中。想在其他地方调用被编译的函数时,只需要import 模块名即可 。

官方给出的Numba AOT示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from numba.pycc import CC

cc = CC('my_module')

@cc.export('multf', 'f8(f8, f8)')
@cc.export('multi', 'i4(i4, i4)')
def mult(a, b):
return a * b

@cc.export('square', 'f8(f8)')
def square(a):
return a ** 2

if __name__ == "__main__":
cc.compile()

首先,程序要用一个模块名实例化一个CC。该模块名是未来我们import时用到的名称。之后,对于想编译的函数,我们要用@cc.export装饰它。@cc.export的第一个参数是调用时的函数名(原来的函数名会被舍弃),第二个参数用于指定函数返回值和参数的类型。做完所有这些准备后,使用cc.compile()即可完成编译。

运行该程序,会得到一个模块文件。根据平台的不同,该模块文件名可能是my_module.somy_module.pydmy_module.cpython-34m.so。不管文件名是什么,只要是在同一个文件夹下,我们就可以用下面的Python命令调用这个模块文件。

1
2
3
4
5
>>> import my_module
>>> my_module.multi(3, 4)
12
>>> my_module.square(1.414)
1.9993959999999997

用Numba做即时编译时,函数的返回值类型和参数类型可填可不填。而Numba提前编译中,必须要填入函数的返回值类型和参数类型。这让编写Numba提前编译的工作量大了不少,已经不像是在写Python,而是在写C了。

还有一点值得注意。和使用即时编译时一样,自定义的函数在调用其他自定义函数时,必须要加上@njit。所以,会出现一个函数即有@njit,又有@cc.export的情况。

学习使用Numba提前编译时,最主要是要学习Numba是怎么用字符串代表参数类型的。比如,i4是32位整型,u1是8位无符号整型,u1[:, :, :]是三维8位无符号整型,void是无返回值。这些表示可以在官方文档里找到。

以我写的Numba AOT PatchMatch的编译代码为例,我们可以看一看参数类型和返回值类型是怎么描述的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
from numba import njit
from numba.pycc import CC

cc = CC('patch_match_module')


@njit
@cc.export('cal_distance', 'f4(i4, i4, i4, i4, u1[:, :, :], u1[:, :, :], i4)')
def cal_distance(x, y, x2, y2, A_padding, B, p):
...


@njit
@cc.export(
'initialization_loop',
'void(u1[:, :, :], u1[:, :, :], i4[:, :, :], f4[:, :], i4, i4, i4[:, :], i4[:, :], i4)'
)
def initialization_loop(A_padding, B, f, dist, A_h, A_w, random_B_r,
random_B_c, p):
...


@njit
@cc.export(
'propagation',
'void(i4[:, :, :], i4, i4, i4, i4, f4[:, :], u1[:, :, :], u1[:, :, :], i4, b1)'
)
def propagation(f, x, y, A_h, A_w, dist, A_padding, B, p_size, is_odd):
...


@njit
@cc.export(
'random_search',
'void(i4[:, :, :], i4, i4, i4, i4, f4[:, :], u1[:, :, :], u1[:, :, :], i4, f4)'
)
def random_search(f, x, y, B_h, B_w, dist, A_padding, B, p_size, alpha=0.5):
...


if __name__ == "__main__":
cc.compile()

运行该程序后,在我的电脑上得到了名为patch_match_module.cp37-win_amd64.pyd的模块文件。可以在其他代码里通过import patch_match_module调用编译好的函数了,比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import patch_match_module

def NNS(img, ref, p_size, itr):
A_h = np.size(img, 0)
A_w = np.size(img, 1)
B_h = np.size(ref, 0)
B_w = np.size(ref, 1)
f, dist, img_padding = initialization(img, ref, A_h, A_w, B_h, B_w, p_size)
for itr in range(1, itr + 1):
if itr % 2 == 0:
for i in range(A_h - 1, -1, -1):
for j in range(A_w - 1, -1, -1):
patch_match_module.propagation(f, i, j, A_h, A_w, dist,
img_padding, ref, p_size,
False)
patch_match_module.random_search(f, i, j, B_h, B_w, dist,
img_padding, ref, p_size,
0.5)
else:
for i in range(A_h):
for j in range(A_w):
patch_match_module.propagation(f, i, j, A_h, A_w, dist,
img_padding, ref, p_size,
True)
patch_match_module.random_search(f, i, j, B_h, B_w, dist,
img_padding, ref, p_size,
0.5)
return f

加上最后这步提前编译后,PatchMatch的运行时间从3秒降低到了0.6秒多。程序从最开始的180秒降到了0.6秒,几乎快了300倍。而且,如果是处理视频,还可以通过复用前一帧信息来减少迭代次数,进一步缩短每一帧的平均处理时间。能加速这么多,并不是我太强,而是Python实在太慢了。纯Python就不应该用来写科学计算程序。

总结

通过阅读这篇文章,相信大家能根据我这次Python PatchMatch性能优化经历,在不阅读Numba文档的前提下自然而然地学会Numba的用法。我把文章中提到的和Numba性能优化有关的知识点按使用顺序总结一下。

  1. 在面向应用的程序中,不要用Python写科学计算程序。哪怕要写,也要尽可能避免在循环中使用大量计算,而是去调用各个库的向量化计算。
  2. 直接在想优化的函数前加@njit装饰。在待优化函数里使用循环、NumPy函数都是很欢迎的。
  3. 如果碰到了Numba不支持的函数,可以通过两种方式解决:1)用等价的Numba支持的函数代替;2)把不支持和支持的部分分离,只加速支持的部分。
  4. 一个带@njit函数在调用另一个自定义的函数时,那个函数也得加上@njit。因此,应该自底向上地实现Numba即时编译函数。
  5. 如果你接受不了计时编译的编译时间,可以使用提前编译技术。使用提前编译时,主要的工作是给参数和返回值标上正确的类型。

Numba确实很容易上手,只要会加@njit,剩下碰到了什么问题去搜索一下就行。Numba的官方文档很详细,想深入学习的话直接看文档就行了。

本项目的代码仓库为:https://github.com/SingleZombie/Fast-Python-PatchMatch 。在原作者仓库的基础上,我添加了PatchMatch_numba_jit.pyPatchMatch_numba_compile.pyPatchMatch_numba_aot.py这三个文件。它们分别表示即时编译运行程序、提前编译编译程序、提前编译运行程序。