0%

2022年中旬,以扩散模型为核心的图像生成模型将AI绘画带入了大众的视野。实际上,在更早的一年之前,就有了一个能根据文字生成高清图片的模型——VQGAN。VQGAN不仅本身具有强大的图像生成能力,更是传承了前作VQVAE把图像压缩成离散编码的思想,推广了「先压缩,再生成」的两阶段图像生成思路,启发了无数后续工作。

VQGAN生成出的高清图片

在这篇文章中,我将对VQGAN的论文和源码中的关键部分做出解读,提炼出VQGAN中的关键知识点。由于VQGAN的核心思想和VQVAE如出一辙,我不会过多地介绍VQGAN的核心思想,强烈建议读者先去学懂VQVAE,再来看VQGAN。

VQGAN 核心思想

VQGAN的论文名为Taming Transformers for High-Resolution Image Synthesis,直译过来是「驯服Transformer模型以实现高清图像合成」。可以看出,该方法是在用Transformer生成图像。可是,为什么这个模型叫做VQGAN,是一个GAN呢?这是因为,VQGAN使用了两阶段的图像生成方法:

  • 训练时,先训练一个图像压缩模型(包括编码器和解码器两个子模型),再训练一个生成压缩图像的模型。
  • 生成时,先用第二个模型生成出一个压缩图像,再用第一个模型复原成真实图像。

其中,第一个图像压缩模型叫做VQGAN,第二个压缩图像生成模型是一个基于Transformer的模型。

为什么会有这种乍看起来非常麻烦的图像生成方法呢?要理解VQGAN的这种设计动机,有两条路线可以走。两条路线看待问题的角度不同,但实际上是在讲同一件事。

第一条路线是从Transformer入手。Transformer已经在文本生成领域大展身手。同时,Transformer也在视觉任务中开始崭露头角。相比擅长捕捉局部特征的CNN,Transformer的优势在于它能更好地融合图像的全局信息。可是,Transformer的自注意力操作开销太大,只能生成一些分辨率较低的图像。因此,作者认为,可以综合CNN和Transformer的优势,先用基于CNN的VQGAN把图像压缩成一个尺寸更小、信息更丰富的小图像,再用Transformer来生成小图像。

第二条路线是从VQVAE入手。VQVAE是VQGAN的前作,它有着和VQGAN一模一样两阶段图像生成方法。不同的是,VQVAE没有使用GAN结构,且其配套的压缩图像生成模型是基于CNN的。为提升VQVAE的生成效果,作者提出了两项改进策略:1) 图像压缩模型VQVAE仅使用了均方误差,压缩图像的复原结果较为模糊,可以把图像压缩模型换成GAN;2) 在生成压缩图片这个任务上,基于CNN的图像生成模型比不过Transformer,可以用Transformer代替原来的CNN。

第一条思路是作者在论文的引言中描述的,听起来比较高大上;而第二条思路是读者读过文章后能够自然总结出来的,相对来说比较清晰易懂。如果你已经理解了VQVAE,你能通过第二条思路瞬间弄懂VQGAN的原理。说难听点,VQGAN就是一个改进版的VQVAE。然而,VQGAN的改进非常有效,且使用了若干技巧来实现带约束(比如根据文字描述)的高清图像生成,有非常多地方值得学习。

在下文中,我将先补充VQVAE的背景以方便讨论,再介绍VQGAN论文的四大知识点:VQGAN的设计细节、生成压缩图像的Transformer的设计细节、带约束图像生成的实现方法、高清图像生成的实现方法。

VQVAE 背景知识补充

VQVAE的学习目标是用一个编码器把图像压缩成离散编码,再用一个解码器把图像尽可能地还原回原图像。

通俗来说,VQVAE就是把一幅真实图像压缩成一个小图像。这个小图像和真实图像有着一些相同的性质:小图像的取值和像素值(0-255的整数)一样,都是离散的;小图像依然是二维的,保留了某些空间信息。因此,VQVAE的示意图画成这样会更形象一些:

但小图像和真实图像有一个关键性的区别:与像素值不同,小图像的离散取值之间没有关联。真实图像的像素值其实是一个连续颜色的离散采样,相邻的颜色值也更加相似。比如颜色254和颜色253和颜色255比较相似。而小图像的取值之间是没有关联的,你不能说编码为1与编码为0和编码为2比较相似。由于神经网络不能很好地处理这种离散量,在实际实现中,编码并不是以整数表示的,而是以类似于NLP中的嵌入向量的形式表示的。VAE使用了嵌入空间(又称codebook)来完成整数序号到向量的转换。

为了让任意一个编码器输出向量都变成一个固定的嵌入向量,VQVAE采取了一种离散化策略:把每个输出向量$z_e(x)$替换成嵌入空间中最近的那个向量$z_q(x)$。$z_e(x)$的离散编码就是$z_q(x)$在嵌入空间的下标。这个过程和把254.9的输出颜色值离散化成255的整数颜色值的原理类似。

VQVAE的损失函数由两部分组成:重建误差和嵌入空间误差。

其中,重建误差就是输入和输出之间的均方误差。

嵌入空间误差为解码器输出向量$z_e(x)$和它在嵌入空间对应向量$z_q(x)$的均方误差。

作者在误差中还使用了一种「停止梯度」的技巧。这个技巧在VQGAN中被完全保留,此处就不过多介绍了。

图像压缩模型 VQGAN

回顾了VQVAE的背景知识后,我们来正式认识VQGAN的几个创新点。第一点,图像压缩模型VQVAE被改进成了VQGAN。

一般VAE重建出来出来的图像都会比较模糊。这是因为VAE只使用了均方误差,而均方误差只能保证像素值尽可能接近,却不能保证图像的感知效果更加接近。为此,作者把GAN的一些方法引入VQVAE,改造出了VQGAN。

具体来说,VQGAN有两项改进。第一,作者用感知误差(perceptual loss)代替原来的均方误差作为VQGAN的重建误差。第二,作者引入了GAN的对抗训练机制,加入了一个基于图块的判别器,把GAN误差加入了总误差。

计算感知误差的方法如下:把两幅图像分别输入VGG,取出中间某几层卷积层的特征,计算特征图像之间的均方误差。如果你之前没学过相关知识,请搜索”perceptual loss”。

基于图块的判别器,即判别器不为整幅图输出一个真或假的判断结果,而是把图像拆成若干图块,分别输出每个图块的判断结果,再对所有图块的判断结果取一个均值。这只是GAN的一种改进策略而已,没有对GAN本身做太大的改动。如果你之前没学过相关知识,请搜索”PatchGAN”。

这样,总的误差可以写成:

其中,$\lambda$是控制两种误差比例的权重。作者在论文中使用了一个公式来自适应地设置$\lambda$。和普通的GAN一样,VQGAN的编码器、解码器(即生成器)、codebook会最小化误差,判别器会最大化误差。

用VQGAN代替VQVAE后,重建图片中的模糊纹理清晰了很多。

有了一个保真度高的图像压缩模型,我们可以进入下一步,训练一个生成压缩图像的模型。

基于 Transformer 的压缩图像生成模型

如前所述,经VQGAN得到的压缩图像与真实图像有一个本质性的不同:真实图像的像素值具有连续性,相邻的颜色更加相似,而压缩图像的像素值则没有这种连续性。压缩图像的这一特性让寻找一个压缩图像生成模型变得异常困难。多数强大的真实图像生成模型(比如GAN)都是输出一个连续的浮点颜色值,再做一个浮点转整数的操作,得到最终的像素值。而对于压缩图像来说,这种输出连续颜色的模型都不适用了。因此,之前的VQVAE使用了一个能建模离散颜色的PixelCNN模型作为压缩图像生成模型。但PixelCNN的表现不够优秀。

恰好,功能强大的Transformer天生就支持建模离散的输出。在NLP中,每个单词都可以用一个离散的数字表示。Transformer会不断生成表示单词的数字,以达到生成句子的效果。

Transformer 随机生成句子的过程

为了让Transformer生成图像,我们可以把生成句子的一个个单词,变成生成压缩图像的一个个像素。但是,要让Transformer生成二维图像,还需要克服一个问题:在生成句子时,Transformer会先生成第一个单词,再根据第一个单词生成第二个单词,再根据第一、第二个单词生成第三个单词……。也就是说,Transformer每次会根据之前所有的单词来生成下一单词。而图像是二维数据,没有先后的概念,怎样让像素和文字一样有先后顺序呢?

VQGAN的作者使用了自回归图像生成模型的常用做法,给图像的每个像素从左到右,从上到下规定一个顺序。有了先后顺序后,图像就可以被视为一个一维句子,可以用Transfomer生成句子的方式来生成图像了。在第$i$步,Transformer会根据前$i - 1$个像素$s_{ < i}$生成第$i$个像素$s_i$,

带约束的图像生成

在生成新图像时,我们更希望模型能够根据我们的需求生成图像。比如,我们希望模型生成「一幅优美的风景画」,又或者希望模型在一幅草图的基础上作画。这些需求就是模型的约束。为了实现带约束的图像生成,一般的做法是先有一个无约束(输入是随机数)的图像生成模型,再在这个模型的基础上把一个表示约束的向量插入进图像生成的某一步。

把约束向量插入进模型的方法是需要设计的,插入约束向量的方法往往和模型架构有着密切关系。比如假设一个生成模型是U-Net架构,我们可以把约束向量和当前特征图拼接在一起,输入进U-Net的每一大层。

为了实现带约束的图像生成,VQGAN的作者再次借鉴了Transformer实现带约束文字生成的方法。许多自然语言处理任务都可以看成是带约束的文字生成。比如机器翻译,其实可以看成在给定一种语言的句子的前提下,让模型「随机」生成一个另一种语言的句子。比如要把「简要访问非洲」翻译成英语,我们可以对之前无约束文字生成的Transformer做一些修改。

也就是说,给定约束的句子$c$,在第$i$步,Transformer会根据前$i-1$个输出单词$s_{ < i}$以及$c$生成第$i$个单词$s_i$。表示约束的单词被添加到了所有输出之前,作为这次「随机生成」的额外输入。

上述方法并不是唯一的文字生成方法。这种文字生成方法被称为”decoder-only”。实际上,也有使用一个编码器来额外维护约束信息的文字生成方法。最早的Transformer就用到了带编码器的方法。

我们同样可以把这种思想搬到压缩图像生成里。比如对于MNIST数据集,我们希望模型只生成0~9这些数字中某一个数字的手写图像。也就是说,约束是类别信息,约束的取值是0~9。我们就可以把这个0~9的约束信息添加到Transformer的输入$s_{ < i}$之前,以实现由类别约束的图像生成。

但这种设计又会产生一个新的问题。假设约束条件不能简单地表示成整数,而是一些其他类型的数据,比如语义分割图像,那该怎么办呢?对于这种以图像形式表示的约束,作者的做法是,再训练另一个VQGAN,把约束图像压缩成另一套压缩图片。这一套压缩图片和生成图像的压缩图片有着不同的codebook,就像两种语言有着不同的单词一样。这样,约束图像也变成了一系列的整数,可以用之前的方法进行带约束图像生成了。

生成高清图像

由于Transformer注意力计算的开销很大,作者在所有配置中都只使用了$16 \times 16$的压缩图像,再增大压缩图像尺寸的话计算资源就不够了。而另一方面,每张图像在VQGAN中的压缩比例是有限的。如果图像压缩得过多,则VQGAN的重建质量就不够好了。因此,设边长压缩了$f$倍,则该方法一次能生成的图片的最大尺寸是$16f \times 16f$。在多项实验中,$f=16$的表现都较好。这样算下来,该方法一次只能生成$256 \times 256$的图片。这种尺寸的图片还称不上高清图片。

为了生成更大尺寸的图片,作者先训练好了一套能生成$256 \times 256$的图片的VQGAN+Transformer,再用了一种基于滑动窗口的采样机制来生成大图片。具体来说,作者把待生成图片划分成若干个$16\times16$像素的图块,每个图块对应压缩图像的一个像素。之后,在每一轮生成时,只有待生成图块周围的$16\times16$个图块($256\times256$个像素)会被输入进VQGAN和Transformer,由Transformer生成一个新的压缩图像像素,再把该压缩图像像素解码成图块。(在下面的示意图中,每个方块是一个图块,transformer的输入是$3\times3$个图块)

这个滑动窗口算法不是那么好理解,需要多想一下才能理解它的具体做法。在理解这个算法时,你可能会有这样的问题:上面的示意图中,待生成像素有的时候在最左边,有的时候在中间,有的时候在右边,每次约束它的像素都不一样。这么复杂的约束逻辑怎么编写?其实,Transformer自动保证了每个像素只会由之前的像素约束,而看不到后面的像素。因此,在实现时,只需要把待生成像素框起来,直接用Transformer预测待生成像素即可,不需要编写额外的约束逻辑。

如果你没有学过Transformer的话,理解这部分会有点困难。Transformer可以根据第1~k-1个像素并行地生成第2~k个像素,且保证生成每个像素时不会偷看到后面像素的信息。因此,假设我们要生成第i个像素,其实是预测了所有第2~k个像素的结果,再取出第i个结果,填回待生成图像。

由于论文篇幅有限,作者没有对滑动窗口机制做过多的介绍,也没有讲带约束的滑动窗口是怎么实现的。如果你在理解这一部分时碰到了问题,不用担心,这很正常。稍后我们会在代码阅读章节彻底理解滑动窗口的实现方法。我也是看了代码才看懂此处的做法。

作者在论文中解释了为什么用滑动窗口生成高清图像是合理的。作者先是讨论了两种情况,只要满足这两种情况中的任意一种,拿滑动窗口生成图像就是合理的。第一种情况是数据集的统计规律是几乎空间不变,也就是说训练集图片每$256\times256$个像素的统计规律是类似的。这和我们拿$3\times3$卷积卷图像是因为图像每$3\times3$个像素的统计规律类似的原理是一样的。第二种情况是有空间上的约束信息。比如之前提到的用语义分割图来指导图像生成。由于语义分割也是一张图片,它给每个待生成像素都提供了额外信息。这样,哪怕是用滑动窗口,在局部语义的指导下,模型也足以生成图像了。

若是两种情况都不满足呢?比如在对齐的人脸数据集上做无约束生成。在对齐的人脸数据集里,每张图片中人的五官所在的坐标是差不多的,图片的空间不变性不满足;做无约束生成,自然也没有额外的空间信息。在这种情况下,我们可以人为地添加一个坐标约束,即从左到右、从上到下地给每个像素标一个序号,把每个滑动窗口里的坐标序号做为约束。有了坐标约束后,就还原成了上面的第二种情况,每个像素有了额外的空间信息,基于滑动窗口的方法依然可行。

学完了论文的四大知识点,我们知道VQGAN是怎么根据约束生成高清图像的了。接下来,我们来看看论文的实验部分,看看作者是怎么证明方法的有效性的。

实验

在实验部分,作者先是分别验证了基于Transformer的压缩图像生成模型较以往模型的优越性(4.1节)、VQGAN较以往模型的优越性(4.4节末尾)、使用VQGAN做图像压缩的必要性及相关消融实验(4.3节),再把整个生成方法综合起来,在多项图像生成任务上与以往的图像生成模型做定量对比(4.4节),最后展示了该方法惊艳的带约束生成效果(4.2节)。

在论文4.1节中,作者验证了基于Transformer的压缩图像生成模型的有效性。之前,压缩图像都是使用能输出离散分布的PixelCNN系列模型来生成的。PixelCNN系列的最强模型是PixelSNAIL。为确保公平,作者对比了相同训练时间、相同训练步数下两个网络在不同训练集下的负对数似然(NLL)指标。结果表明,基于Transformer的模型确实训练得更快。

对于直接能建模离散分布的模型来说,NLL就是交叉熵损失函数。

在论文4.4节末尾,作者将VQGAN和之前的图像压缩模型对比,验证了引入感知误差和GAN结构的有效性。作者汇报了各模型重建图像集与原数据集(ImageNet的训练集和验证集)的FID(指标FID是越低越好)。同时,结果也说明,增大codebook的尺寸或者编码种类都能提升重建效果。

在论文4.3节中,作者验证了使用VQGAN的必要性。作者训了两个模型,一个直接让Transformer做真实图像生成,一个用VQGAN把图像边长压缩2倍,再用Transformer生成压缩图像。经比较,使用了VQGAN后,图像生成速度快了10多倍,且图像生成效果也有所提升。

另外,作者还做了有关图像边长压缩比例$f$的消融实验。作者固定让Transformer生成$16 \times 16$的压缩图片,即每次训练时用到的图像尺寸都是$16f \times 16f$。之后,作者训练训练了不同$f$下的模型,用各个模型来生成图片。结果显示$f=16$时效果最好。这是因为,在固定Transformer的生成分辨率的前提下,$f$越小,Transformer的感受野越小。如果Transformer的感受野过小,就学习不到足够的信息。

在论文4.4节中,作者探究了VQGAN+Transformer在多项基准测试(benchmark)上的结果。

首先是语义图像合成(根据语义分割图像来生成)任务。本文的这套方法还不错。

接着是人脸生成任务。这套方法表现还行,但还是比不过专精于某一任务的GAN。

作者还比较了各模型在ImageNet上的生成结果。这一比较的数据量较多,欢迎大家自行阅读原论文。

在论文4.2节中,作者展示了多才多艺的VQGAN+Transformer在各种约束下的图像生成结果。这些图像都是按照默认配置生成的,大小为$256\times256$。

作者还展示了使用了滑动窗口算法后,模型生成的不同分辨率的图像。

本文开头的那张高清图片也来自论文。

总结

VQGAN是一个改进版的VQVAE,它将感知误差和GAN引入了图像压缩模型,把压缩图像生成模型替换成了更强大的Transformer。相比纯种的GAN(如StyleGAN),VQGAN的强大之处在于它支持带约束的高清图像生成。VQGAN借助NLP中”decoder-only”策略实现了带约束图像生成,并使用滑动窗口机制实现了高清图像生成。虽然在某些特定任务上VQGAN还是落后于其他GAN,但VQGAN的泛化性和灵活性都要比纯种GAN要强。它的这些潜力直接促成了Stable Diffusion的诞生。

如果你是读完了VQVAE再来读的VQGAN,为了完全理解VQGAN,你只需要掌握本文提到的4个知识点:VQVAE到VQGAN的改进方法、使用Transformer做图像生成的方法、使用”decoder-only”策略做带约束图像生成的方法、用滑动滑动窗口生成任意尺寸的图片的思想。

代码阅读

在代码阅读章节中,我将先简略介绍官方源码的项目结构以方便大家学习,再介绍代码中的几处核心代码。具体来说,我会介绍模型是如何组织配置文件的、模型的定义代码在哪、训练代码在哪、采样代码在哪,同时我会主要分析VQGAN的结构、Transformer的结构、损失函数、滑动窗口采样算法这几部分的代码。

官方源码地址:https://github.com/CompVis/taming-transformers。

官方的Git仓库里有很多很大的图片,且git记录里还藏了一些很大的数据,整个Git仓库非常大。如果你的网络不好,建议以zip形式下载仓库,或者只把代码部分下载下来。

项目结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
├─assets
├─configs
├─scripts
└─taming
├─data
│ └─conditional_builder
├─models
└─modules
├─diffusionmodules
├─discriminator
├─losses
├─misc
├─transformer
└─vqvae

configs目录下存放的是模型配置文件。VQGAN和Transformer的模型配置是分开来放的。每个模型配置文件都会指向一个Python模型类,比如taming.models.vqgan.VQModel,配置里的参数就是模型类的初始化参数。我们可用通过阅读配置文件找到模型的定义位置。

运行脚本包括根目录下的main.pyscripts文件夹下的脚本。main.py是用于训练的。scripts文件夹下有各种采样脚本和数据集可视化脚本。

taming是源代码的主目录。其data子文件夹下放置了各数据集的预处理代码,models放置了VQGAN和Transformer PyTorch模型的定义代码,modules则放置了模型中用到的模块,主要包括VQGAN编码解码模块(diffusionmodules)、判别器模块(discriminator)、误差模块(losses)、Transformer模块(transformer)、codebook模块(vqvae)。

VQGAN 模型结构

打开configs\faceshq_vqgan.yaml,我们能够找到高清人脸生成任务使用的VQGAN模型配置。我们来学习一下这个模型的定义方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
model:
base_learning_rate: 4.5e-6
target: taming.models.vqgan.VQModel
params:
embed_dim: 256
n_embed: 1024
ddconfig:
...

lossconfig:
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
params:
...

从配置文件的target字段中,我们知道VQGAN定义在模块taming.models.vqgan.VQModel中。我们可以打开taming\models\vqgan.py这个文件,查看其中VQModel类的代码。

首先先看一下初始化函数。初始化函数主要是初始化了encoderdecoderlossquantize这几个模块,我们可以从文件开头的import语句中找到这几个模块的定义位置。不过,先不急,我们来继续看一下模型的前向传播函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import GumbelQuantize
from taming.modules.vqvae.quantize import EMAVectorQuantizer

class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap, sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.image_key = image_key
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor

模型的前向传播逻辑非常清晰。self.encoder可以把一张图片变为特征,self.decoder可以把特征变回图片。self.quant_convpost_quant_conv则分别完成了编码器到codebook、codebook到解码器的通道数转换。self.quantize实现了VQVAE和VQGAN中那个找codebook里的最近邻、替换成最近邻的操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info

def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec

def forward(self, input):
quant, diff, _ = self.encode(input)
dec = self.decode(quant)
return dec, diff

接下来,我们再看一看VQGAN的各个模块的定义。编码器和解码器的定义都可以在taming\modules\diffusionmodules\model.py里找到。VQGAN使用的编码器和解码器基于DDPM论文中的U-Net架构(而此架构又可以追溯到PixelCNN++的模型架构)。相比于最经典的U-Net,此U-Net每一层由若干个残差块和若干个自注意力块构成。为了把这个U-Net用到VQGAN里,U-Net的下采样部分和上采样部分被拆开,分别做成了VQGAN的编码器和解码器。

此处代码过长,我就只贴出部分关键代码了。以下是编码器的__init__函数和forward函数的关键代码。self.down存储了U-Net各层的模块。对于第i层,down[i].block是所有残差块,down[i].attn是所有自注意力块,down[i].downsample是下采样操作。它们在forward里会被依次调用。解码器的结构与之类似,只不过下采样变成了上采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, **ignore_kwargs):
super().__init__()
...
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)

...


def forward(self, x):
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
...

return h

之后,我们再看看离散化层的代码,即把编码器的输出变成codebook里的嵌入的实现代码。作者在taming\modules\vqvae\quantize.py中提供了VQVAE原版的离散化操作以及若干个改进过的离散化操作。我们就来看一下原版的离散化模块VectorQuantizer是怎么实现的。

离散化模块的初始化非常简洁,主要是初始化了一个嵌入层。

1
2
3
4
5
6
7
8
9
10
class VectorQuantizer(nn.Module):
def __init__(self, n_e, e_dim, beta):
super(VectorQuantizer, self).__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta

self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

在前向传播时,作者先是算出了编码器输出z和所有嵌入的距离d,再用argmin算出了最近邻嵌入的下标min_encodings,最后根据下标取出解码器输入z_q。同时,该函数还计算了其他几个可能用到的量,比如和codebook有关的误差 loss。注意,在计算lossz_q时,作者都使用到了停止梯度算子(.detach())。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def forward(self, z):
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
torch.matmul(z_flattened, self.embedding.weight.t())

## could possible replace this here
# #\start...
# find closest encodings
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)

min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.n_e).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)

z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
#.........\end


# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
torch.mean((z_q - z.detach()) ** 2)

# preserve gradients
z_q = z + (z_q - z).detach()

# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()

return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

VQGAN的三个主要模块已经看完了。最后,我们来看一下误差的定义。误差的定义在taming\modules\losses\vqperceptual.pyVQLPIPSWithDiscriminator类里。误差类名里的LPIPS(Learned Perceptual Image Patch Similarity,学习感知图像块相似度)就是感知误差的全称,”WithDiscriminator”表示误差是带了判定器误差的。我们来把这两类误差分别看一下。

说实话,这个误差模块乱得一塌糊涂,一边自己在算误差,一边又维护了codebook误差和重建误差的权重,最后会把自己维护的两个误差和其他误差合在一起输出。功能全部耦合在一起。我们就跳过这个类的实现细节,主要关注self.perceptual_lossself.discriminator是怎么调用其他模块的。

1
2
3
4
5
6
7
8
9
10
from taming.modules.losses.lpips import LPIPS
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, ...):
super().__init__()

self.perceptual_loss = LPIPS().eval()

self.discriminator = NLayerDiscriminator...

感知误差模块在taming\modules\losses\vqperceptual.py文件里。这个文件来自GitHub项目 PerceptualSimilarity。

感知误差可以简单地理解为两张图片在VGG中几个卷积层输出的误差的加权和。加权的权重是可以学习的。作者使用的是已经学习好的感知误差。感知误差的初始化函数如下。其中,self.lin0等模块就是算权重的模块,self.net是VGG。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class LPIPS(nn.Module):
# Learned perceptual metric
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512] # vg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False

在算误差时,先是把图像inputtarget都输入进VGG,获取各层输出outs0, outs1,再求出两个图像的输出的均方误差diffs,最后用lins给各层误差加权,求和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val

GAN的判别器写在taming\modules\discriminator\model.py文件里。这个文件来自GitHub上的 pytorch-CycleGAN-and-pix2pix 项目。这个判别器非常简单,就是一个全卷积网络。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator as in Pix2Pix
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminator, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d

kw = 4
padw = 1
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]

nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]

sequence += [
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)

def forward(self, input):
"""Standard forward."""
return self.main(input)

Transformer 模型结构

此方法使用的Transformer是GPT2。我们先看一下该项目封装Transformer的模型类taming.models.cond_transformer.Net2NetTransformer,再稍微看一下GPT类taming.modules.transformer.mingpt.GPT的具体实现。

Net2NetTransformer主要是实现了论文中提到的带约束生成。它会把输入x和约束c分别用一个VQGAN转成压缩图像,把图像压扁成一维,再调用GPT。我们来看一下这个类的主要内容。

初始化函数主要是初始化了输入图像的VQGAN self.first_stage_model、约束图像的VQGAN self.cond_stage_model、Transformer self.transformer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Net2NetTransformer(pl.LightningModule):
def __init__(self,
transformer_config,
first_stage_config,
cond_stage_config,
permuter_config=None,
ckpt_path=None,
ignore_keys=[],
first_stage_key="image",
cond_stage_key="depth",
downsample_cond_size=-1,
pkeep=1.0,
sos_token=0,
unconditional=False,
):
super().__init__()
self.be_unconditional = unconditional
self.sos_token = sos_token
self.first_stage_key = first_stage_key
self.cond_stage_key = cond_stage_key
self.init_first_stage_from_ckpt(first_stage_config)
self.init_cond_stage_from_ckpt(cond_stage_config)
if permuter_config is None:
permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
self.permuter = instantiate_from_config(config=permuter_config)
self.transformer = instantiate_from_config(config=transformer_config)

if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.downsample_cond_size = downsample_cond_size
self.pkeep = pkeep

def init_first_stage_from_ckpt(self, config):
model = instantiate_from_config(config)
model = model.eval()
model.train = disabled_train
self.first_stage_model = model

def init_cond_stage_from_ckpt(self, config):
...
self.cond_stage_model = ...

模型的前向传播函数如下。一开始,函数调用encode_to_zencode_to_c,根据self.cond_stage_modelself.first_stage_model把约束图像和输入图像编码成压扁至一维的压缩图像。之后函数做了一个类似Dropout的操作,根据self.pkeep随机替换掉约束编码。最后,函数把约束编码和输入编码拼接起来,使用通常方法调用Transformer。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def forward(self, x, c):
# one step to produce the logits
_, z_indices = self.encode_to_z(x)
_, c_indices = self.encode_to_c(c)

if self.training and self.pkeep < 1.0:
mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
device=z_indices.device))
mask = mask.round().to(dtype=torch.int64)
r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
a_indices = mask*z_indices+(1-mask)*r_indices
else:
a_indices = z_indices

cz_indices = torch.cat((c_indices, a_indices), dim=1)

# target includes all sequence elements (no need to handle first one
# differently because we are conditioning)
target = z_indices
# make the prediction
logits, _ = self.transformer(cz_indices[:, :-1])
# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
logits = logits[:, c_indices.shape[1]-1:]

return logits, target

GPT2的结构不是本文的重点,我们就快速把模型结构过一遍了。GPT2的模型定义在taming.modules.transformer.mingpt.GPT里。GPT2的结构并不复杂,就是一个只有解码器的Transformer。前向传播时,数据先通过嵌入层self.tok_emb,再经过若干个Transformer模块self.blocks,最后过一个LayerNorm层self.ln_f和线性层self.head

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class GPT(nn.Module):

def forward(self, idx, embeddings=None, targets=None):
# forward the GPT model
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector

if embeddings is not None: # prepend explicit embeddings
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

t = token_embeddings.shape[1]
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(token_embeddings + position_embeddings)
x = self.blocks(x)
x = self.ln_f(x)
logits = self.head(x)

# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

return logits, loss

每个Transformer块就是非常经典的自注意力加全连接层。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Block(nn.Module):
""" an unassuming Transformer block """
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(), # nice
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.resid_pdrop),
)

def forward(self, x, layer_past=None, return_present=False):
# TODO: check that training still works
if return_present: assert not self.training
# layer past: tuple of length two with B, nh, T, hs
attn, present = self.attn(self.ln1(x), layer_past=layer_past)

x = x + attn
x = x + self.mlp(self.ln2(x))
if layer_past is not None or return_present:
return x, present
return x

基于滑动窗口的带约束图像生成

看完了所有模型的结构,我们最后来学习一下论文中没能详细介绍的滑动窗口算法。在scripts\taming-transformers.ipynb里有一个采样算法的最简实现,我们就来学习一下这份代码。

这份代码可以根据一幅语义分割图像来生成高清图像。一开始,代码会读入模型和语义分割图像。大致的代码为:

1
2
3
4
5
6
7
from taming.models.cond_transformer import Net2NetTransformer
model = Net2NetTransformer(**config.model.params)
from PIL import Image
import numpy as np
segmentation_path = "data/sflckr_segmentations/norway/25735082181_999927fe5a_b.png"
segmentation = Image.open(segmentation_path)
...


之后,代码把约束图像用对应的VQGAN编码进压缩空间,得到c_indices。由于待生成图像为空,我们可以随便生成一个待生成图像的压缩图像z_indices,代码中使用了randint初始化待生成的压缩图像。

1
2
3
4
5
6
7
8
c_code, c_indices = model.encode_to_c(segmentation)
z_indices = torch.randint(codebook_size, z_indices_shape, device=model.device)

idx = z_indices
idx = idx.reshape(z_code_shape[0],z_code_shape[2],z_code_shape[3])

cidx = c_indices
cidx = cidx.reshape(c_code.shape[0],c_code.shape[2],c_code.shape[3])

最后就是最关键的滑动窗口采样部分了。我们先稍微浏览一遍代码,再详细地一行一行看过去。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
temperature = 1.0
top_k = 100

for i in range(0, z_code_shape[2]-0):
if i <= 8:
local_i = i
elif z_code_shape[2]-i < 8:
local_i = 16-(z_code_shape[2]-i)
else:
local_i = 8
for j in range(0,z_code_shape[3]-0):
if j <= 8:
local_j = j
elif z_code_shape[3]-j < 8:
local_j = 16-(z_code_shape[3]-j)
else:
local_j = 8

i_start = i-local_i
i_end = i_start+16
j_start = j-local_j
j_end = j_start+16

patch = idx[:,i_start:i_end,j_start:j_end]
patch = patch.reshape(patch.shape[0],-1)
cpatch = cidx[:, i_start:i_end, j_start:j_end]
cpatch = cpatch.reshape(cpatch.shape[0], -1)
patch = torch.cat((cpatch, patch), dim=1)
logits,_ = model.transformer(patch[:,:-1])
logits = logits[:, -256:, :]
logits = logits.reshape(z_code_shape[0],16,16,-1)
logits = logits[:,local_i,local_j,:]

logits = logits/temperature

if top_k is not None:
logits = model.top_k_logits(logits, top_k)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx[:,i,j] = torch.multinomial(probs, num_samples=1)

x_sample = model.decode_to_img(idx, z_code_shape)
show_image(x_sample)

一开始的temperaturetop_k是得到logit后的采样参数,和滑动窗口算法无关。
1
2
temperature = 1.0
top_k = 100

进入生成图像循环后,i, j分别表示压缩图像的竖索引和横索引,i_start, i_end, j_start, j_end是滑动窗口上下左右边界。

1
2
3
4
5
6
7
8
for i in range(0, z_code_shape[2]-0):
...
for j in range(0,z_code_shape[3]-0):
...
i_start = i-local_i
i_end = i_start+16
j_start = j-local_j
j_end = j_start+16

为了获取这四个滑动窗口的范围,代码用了若干条件语句计算待生成像素在滑动窗口里的相对位置local_i, local_j

1
2
3
4
5
6
7
8
9
10
11
12
13
14
for i in range(0, z_code_shape[2]-0):
if i <= 8:
local_i = i
elif z_code_shape[2]-i < 8:
local_i = 16-(z_code_shape[2]-i)
else:
local_i = 8
for j in range(0,z_code_shape[3]-0):
if j <= 8:
local_j = j
elif z_code_shape[3]-j < 8:
local_j = 16-(z_code_shape[3]-j)
else:
local_j = 8

得到了滑动窗口的边界后,代码用滑动窗口从约束图像的压缩图像和待生成图像的压缩图像上各取出一个图块,并拼接起来。

1
2
3
4
5
patch = idx[:,i_start:i_end,j_start:j_end]
patch = patch.reshape(patch.shape[0],-1)
cpatch = cidx[:, i_start:i_end, j_start:j_end]
cpatch = cpatch.reshape(cpatch.shape[0], -1)
patch = torch.cat((cpatch, patch), dim=1)

之后,只需要把拼接的图块直接输入进Transformer,得到输出logits,再用local_i,local_j去输出图块的对应位置取出下一个压缩图像像素的概率分布,就可以随机生成下一个压缩图像像素了。如前文所述,Transformer类会把二维的图块压扁到一维,输入进GPT。同时,GPT会自动保证前面的像素看不到后面的像素,我们不需要人为地指定约束像素。这个地方的调用逻辑其实非常简单。

1
2
3
4
logits,_ = model.transformer(patch[:,:-1])
logits = logits[:, -256:, :]
logits = logits.reshape(z_code_shape[0],16,16,-1)
logits = logits[:,local_i,local_j,:]

最后只要从logits里采样,把采样出的压缩图像像素填入idx,就完成了一步生成。

1
2
3
4
5
6
7
logits = logits/temperature

if top_k is not None:
logits = model.top_k_logits(logits, top_k)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx[:,i,j] = torch.multinomial(probs, num_samples=1)

反复执行循环,就能将压缩图像生成完毕。最后将压缩图像过一遍VQGAN的解码器即可得到最终的生成图像。

1
2
x_sample = model.decode_to_img(idx, z_code_shape)
show_image(x_sample)

参考资料

VQGAN论文:https://arxiv.org/abs/2012.09841

VQGAN GitHub:https://github.com/CompVis/taming-transformers

如果你需要补充学习早期工作,欢迎阅读我之前的文章。

Transformer解读

PixelCNN解读

VQVAE解读

在这篇文章中,我将详细地介绍一个英中翻译 Transformer 的 PyTorch 实现。这篇文章会完整地展示一个深度学习项目的搭建过程,从数据集准备,到模型定义、训练。这篇文章不仅会讲解如何把 Transformer 的论文翻译成代码,还会讲清楚代码实现中的诸多细节,并分享我做实验时碰到的种种坑点。相信初学者能够从这篇文章中学到丰富的知识。

项目网址: https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/Transformer

如果你对 Transformer 的论文不熟,欢迎阅读我之前的文章:Attention Is All You Need (Transformer) 论文精读

数据集准备

我在 https://github.com/P3n9W31/transformer-pytorch 项目中找到了一个较小的中英翻译数据集。数据集只有几KB大小,中英词表只有10000左右,比较适合做Demo。如果要实现更加强大实用的模型,则需要换更大的数据集。但相应地,你要多花费更多的时间来训练。

我在代码仓库中提供了data_load.py文件。执行这个文件后,实验所需要的数据会自动下载到项目目录的data文件夹下。

该数据集由cn.txt, en.txt, cn.txt.vocab.tsv, en.txt.vocab.tsv这四个文件组成。前两个文件包含相互对应的中英文句子,其中中文已做好分词,英文全为小写且标点已被分割好。后两个文件是预处理好的词表。语料来自2000年左右的中国新闻,其第一条的中文及其翻译如下:

1
2
目前 粮食 出现 阶段性 过剩 , 恰好 可以 以 粮食 换 森林 、 换 草地 , 再造 西部 秀美 山川 。
the present food surplus can specifically serve the purpose of helping western china restore its woodlands , grasslands , and the beauty of its landscapes .

词表则统计了各个单词的出现频率。通过使用词表,我们能实现单词和序号的相互转换(比如中文里的5号对应“的”字,英文里的5号对应”the”)。词表的前四个单词是特殊字符,分别为填充字符、频率太少没有被加入词典的词语、句子开始字符、句子结束字符。

1
2
3
4
5
6
7
8
<PAD>	1000000000
<UNK> 1000000000
<S> 1000000000
</S> 1000000000
的 8461
是 2047
和 1836
在 1784
1
2
3
4
5
6
7
8
<PAD>	1000000000
<UNK> 1000000000
<S> 1000000000
</S> 1000000000
the 13680
and 6845
of 6259
to 4292

只要运行一遍data_load.py下好数据后,我们待会就能用load_train_data()来获取已经打成batch的训练数据,并用API获取cn2idx, idx2cn, en2idx, idx2en这四个描述中英文序号与单词转换的词典。我们会在之后的训练代码里见到它们的用法。

Transformer 模型

准备好数据后,接下来就要进入这个项目最重要的部分——Transformer 模型实现了。我将按照代码的执行顺序,从前往后,自底向上地介绍 Transformer 的各个模块:Positional Encoding, MultiHeadAttention, Encoder & Decoder, 最后介绍如何把各个模块拼到一起。在这个过程中,我还会着重介绍一个论文里没有提及,但是代码实现时非常重要的一个细节——<pad>字符的处理。

说实话,用 PyTorch 实现 Transformer 没有什么有变数的地方,大家的代码写得都差不多,我也是参考着别人的教程写的。但是,Transformer 的代码实现中有很多坑。大部分人只会云淡风轻地介绍一下最终的代码成品,不会去讲他们 debug 耗费了多少时间,哪些地方容易出错。而我会着重讲一下代码中的一些细节,以及我碰到过的问题。

Positional Encoding

模型一开始是一个 Embedding 层加一个 Positional Encoding。Embedding 在 PyTorch 里已经有实现了,且不是文章的创新点,我们就直接来看 Positional Encoding 的写法。

求 Positional Encoding,其实就是求一个二元函数的许多函数值构成的矩阵。对于二元函数$PE(pos, i)$,我们要求出$pos \in [0, seqlen - 1], i \in [0, d_{model} - 1]$时所有的函数值,其中,$seqlen$是该序列的长度,$d_{model}$是每一个词向量的长度。

理论上来说,每个句子的序列长度$seqlen$是不固定的。但是,我们可以提前预处理一个$seqlen$很大的 Positional Encoding 矩阵 。每次有句子输入进来,根据这个句子的序列长度,去预处理好的矩阵里取一小段出来即可。

这样,整个类的实现应该如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class PositionalEncoding(nn.Module):

def __init__(self, d_model: int, max_seq_len: int):
super().__init__()

# Assume d_model is an even number for convenience
assert d_model % 2 == 0

i_seq = torch.linspace(0, max_seq_len - 1, max_seq_len)
j_seq = torch.linspace(0, d_model - 2, d_model // 2)
pos, two_i = torch.meshgrid(i_seq, j_seq)
pe_2i = torch.sin(pos / 10000**(two_i / d_model))
pe_2i_1 = torch.cos(pos / 10000**(two_i / d_model))
pe = torch.stack((pe_2i, pe_2i_1), 2).reshape(1, max_seq_len, d_model)

self.register_buffer('pe', pe, False)

def forward(self, x: torch.Tensor):
n, seq_len, d_model = x.shape
pe: torch.Tensor = self.pe
assert seq_len <= pe.shape[1]
assert d_model == pe.shape[2]
rescaled_x = x * d_model**0.5
return rescaled_x + pe[:, 0:seq_len, :]

代码中有不少需要讲解的部分。首先,先看一下预处理好的矩阵pe是怎么在__init__中算出来的。pe可以很直接地用两层循环算出来。由于这段预处理代码只会执行一次,相对于冗长的训练时间,哪怕生成pe的代码性能差一点也没事。然而,作为一个编程高手,我准备秀一下如何用并行的方法求出pe

为了并行地求pe,我们要初始化一个二维网格,表示自变量$pos, i$。生成网格可以用下面的代码实现。(由于$i$要分奇偶讨论,$i$的个数是$\frac{d_{model}}{2}$)

1
2
3
i_seq = torch.linspace(0, max_seq_len - 1, max_seq_len)
j_seq = torch.linspace(0, d_model - 2, d_model // 2)
pos, two_i = torch.meshgrid(i_seq, j_seq)

torch.meshgrid用于生成网格。比如torch.meshgrid([0, 1], [0, 1])就可以生成[[(0, 0), (0, 1)], [(1, 0), (1, 1)]]这四个坐标构成的网格。不过,这个函数会把坐标的两个分量分别返回。比如:

1
2
3
i, j = torch.meshgrid([0, 1], [0, 1])
# i: [[0, 0], [1, 1]]
# j: [[0, 1], [0, 1]]

利用这个函数的返回结果,我们可以把pos, two_i套入论文的公式,并行地分别算出奇偶位置的 PE 值。

1
2
pe_2i = torch.sin(pos / 10000**(two_i / d_model))
pe_2i_1 = torch.cos(pos / 10000**(two_i / d_model))

有了奇偶处的值,现在的问题是怎么把它们优雅地拼到同一个维度上。我这里先把它们堆成了形状为seq_len, d_model/2, 2的一个张量,再把最后一维展平,就得到了最后的pe矩阵。这一操作等于新建一个seq_len, d_model形状的张量,再把奇偶位置处的值分别填入。

1
pe = torch.stack((pe_2i, pe_2i_1), 2).reshape(1, max_seq_len, d_model)

最后,要注意一点。只用 self.pe = pe 记录这个量是不够好的。我们最好用 self.register_buffer('pe', pe, False) 把这个量登记成 torch.nn.Module 的一个存储区(这一步会自动完成self.pe = pe)。这里涉及到 PyTorch 的一些知识了。

PyTorch 的 Module 会记录两类参数,一类是 parameter 可学习参数,另一类是 buffer 不可学习的参数。把变量登记成 buffer 的最大好处是,在使用 model.to(device) 把一个模型搬到另一个设备上时,所有 parameterbuffer 都会自动被搬过去。另外,bufferparameter 一样,也可以被记录到 state_dict 中,并保存到文件里。register_buffer 的第三个参数决定了是否将变量加入 state_dict。由于 pe 可以直接计算,不需要记录,可以把这个参数设成 False

预处理好 pe 后,用起来就很方便了。每次读取输入的序列长度,从中取一段出来即可。

另外,Transformer 给嵌入层乘了个系数$\sqrt{d_{model}}$。为了方便起见,我把这个系数放到了 PositionalEncoding 类里面。

1
2
3
4
5
6
7
def forward(self, x: torch.Tensor):
n, seq_len, d_model = x.shape
pe: torch.Tensor = self.pe
assert seq_len <= pe.shape[1]
assert d_model == pe.shape[2]
rescaled_x = x * d_model**0.5
return rescaled_x + pe[:, 0:seq_len, :]

Scaled Dot-Product Attention

下一步是多头注意力层。为了实现多头注意力,我们先要实现 Transformer 里经典的注意力计算。而在讲注意力计算之前,我还要补充一下 Transformer 中有关 mask 的一些知识。

Transformer 里的 mask

Transformer 最大的特点就是能够并行训练。给定翻译好的第1~n个词语,它默认会并行地预测第2~(n+1)个下一个词语。为了模拟串行输出的情况,第$t$个词语不应该看到第$t+1$个词语之后的信息。

输入信息 输出
(y1, —, —, —) y2
(y1, y2, —, —) y3
(y1, y2, y3, —) y4
(y1, y2, y3, y4) y5

为了实现这一功能,Transformer 在Decoder里使用了掩码。掩码取1表示这个地方的数是有效的,取0表示这个地方的数是无效的。Decoder 里的这种掩码应该是一个上三角全1矩阵。

掩码是在注意力计算中生效的。对于掩码取0的区域,其softmax前的$QK^T$值取负无穷。这是因为,对于softmax

令$x_i=-\infty$可以让它在 softmax 的分母里不产生任何贡献。

以上是论文里提到的 mask,它用来模拟 Decoder 的串行推理。而在代码实现中,还有其他地方会产生 mask。在生成一个 batch 的数据时,要给句子填充 <pad>。这个特殊字符也没有实际意义,不应该对计算产生任何贡献。因此,有 <pad> 的地方的 mask 也应该为0。之后讲 Transformer 模型类时,我会介绍所有的 mask 该怎么生成,这里我们仅关注注意力计算是怎么用到 mask 的。

注意力计算

补充完了背景知识,我们来看注意力计算的实现代码。由于注意力计算没有任何的状态,因此它应该写成一个函数,而不是一个类。我们可以轻松地用 PyTorch 代码翻译注意力计算的公式。(注意,我这里的 mask 表示哪些地方要填负无穷,而不是像之前讲的表示哪些地方有效)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
MY_INF = 1e12

def attention(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: Optional[torch.Tensor] = None):
'''
Note: The dtype of mask must be bool
'''
# q shape: [n, heads, q_len, d_k]
# k shape: [n, heads, k_len, d_k]
# v shape: [n, heads, k_len, d_v]
assert q.shape[-1] == k.shape[-1]
d_k = k.shape[-1]
# tmp shape: [n, heads, q_len, k_len]
tmp = torch.matmul(q, k.transpose(-2, -1)) / d_k**0.5
if mask is not None:
tmp.masked_fill_(mask, -MY_INF)
tmp = F.softmax(tmp, -1)
# tmp shape: [n, heads, q_len, d_v]
tmp = torch.matmul(tmp, v)
return tmp

这里有一个很坑的地方。引入了 <pad> 带来的 mask 后,会产生一个新的问题:可能一整行数据都是失效的,softmax 用到的所有 $x_i$ 可能都是负无穷。

这个数是没有意义的。如果用torch.inf来表示无穷大,就会令exp(torch.inf)=0,最后 softmax 结果会出现 NaN,代码大概率是跑不通的。

但是,大多数 PyTorch Transformer 教程压根就没提这一点,而他们的代码又还是能够跑通。拿放大镜仔细对比了代码后,我发现,他们的无穷大用的不是 torch.inf,而是自己随手设的一个极大值。这样,exp(-MY_INF)得到的不再是0,而是一个极小值。softmax 的结果就会等于分母的项数,而不是 NaN,不会有数值计算上的错误。

Multi-Head Attention

有了注意力计算,就可以实现多头注意力层了。多头注意力层是有学习参数的,它应该写成一个类。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class MultiHeadAttention(nn.Module):

def __init__(self, heads: int, d_model: int, dropout: float = 0.1):
super().__init__()

assert d_model % heads == 0
# dk == dv
self.d_k = d_model // heads
self.heads = heads
self.d_model = d_model
self.q = nn.Linear(d_model, d_model)
self.k = nn.Linear(d_model, d_model)
self.v = nn.Linear(d_model, d_model)
self.out = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)

def forward(self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: Optional[torch.Tensor] = None):
# batch should be same
assert q.shape[0] == k.shape[0]
assert q.shape[0] == v.shape[0]
# the sequence length of k and v should be aligned
assert k.shape[1] == v.shape[1]

n, q_len = q.shape[0:2]
n, k_len = k.shape[0:2]
q_ = self.q(q).reshape(n, q_len, self.heads, self.d_k).transpose(1, 2)
k_ = self.k(k).reshape(n, k_len, self.heads, self.d_k).transpose(1, 2)
v_ = self.v(v).reshape(n, k_len, self.heads, self.d_k).transpose(1, 2)

attention_res = attention(q_, k_, v_, mask)
concat_res = attention_res.transpose(1, 2).reshape(
n, q_len, self.d_model)
concat_res = self.dropout(concat_res)

output = self.out(concat_res)
return output

这段代码一处很灵性的地方。在 Transformer 的论文中,多头注意力是先把每个词的表示拆成$h$个头,再对每份做投影、注意力,最后拼接起来,再投影一次。其实,拆开与拼接操作是多余的。我们可以通过一些形状上的操作,等价地实现拆开与拼接,以提高运行效率。

具体来说,我们可以一开始就让所有头的数据经过同一个线性层。之后在做注意力之前把头和序列数这两维转置一下。这两步操作和拆开来做投影、注意力是等价的。做完了注意力操作之后,再把两个维度转置回来,这和拼接操作是等价的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def forward(self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: Optional[torch.Tensor] = None):
# batch should be same
assert q.shape[0] == k.shape[0]
assert q.shape[0] == v.shape[0]
# the sequence length of k and v should be aligned
assert k.shape[1] == v.shape[1]

n, q_len = q.shape[0:2]
n, k_len = k.shape[0:2]
q_ = self.q(q).reshape(n, q_len, self.heads, self.d_k).transpose(1, 2)
k_ = self.k(k).reshape(n, k_len, self.heads, self.d_k).transpose(1, 2)
v_ = self.v(v).reshape(n, k_len, self.heads, self.d_k).transpose(1, 2)

attention_res = attention(q_, k_, v_, mask)
concat_res = attention_res.transpose(1, 2).reshape(
n, q_len, self.d_model)
concat_res = self.dropout(concat_res)

output = self.out(concat_res)
return output

前馈网络

前馈网络太简单了,两个线性层,没什么好说的。注意内部那个隐藏层的维度大小$d_{ff}$会比$d_{model}$更大一点。

1
2
3
4
5
6
7
8
9
10
11
12
13
class FeedForward(nn.Module):

def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.layer1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.layer2 = nn.Linear(d_ff, d_model)

def forward(self, x):
x = self.layer1(x)
x = self.dropout(F.relu(x))
x = self.layer2(x)
return x

Encoder & Decoder

准备好一切组件后,就可以把模型一层一层搭起来了。先搭好每个 Encoder 层和 Decoder 层,再拼成 Encoder 和 Decoder。

Encoder 层和 Decoder 层的结构与论文中的描述一致,且每个子层后面都有一个 dropout,和上一层之间使用了残差连接。归一化的方法是 LayerNorm。顺带一提,不仅是这些层,前面很多子层的计算中都加入了 dropout。

再提一句 mask。由于 encoder 和 decoder 的输入不同,它们的填充情况不同,产生的 mask 也不同。后文会展示这些 mask 的生成方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class EncoderLayer(nn.Module):

def __init__(self,
heads: int,
d_model: int,
d_ff: int,
dropout: float = 0.1):
super().__init__()
self.self_attention = MultiHeadAttention(heads, d_model, dropout)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)

def forward(self, x, src_mask: Optional[torch.Tensor] = None):
tmp = self.self_attention(x, x, x, src_mask)
tmp = self.dropout1(tmp)
x = self.norm1(x + tmp)
tmp = self.ffn(x)
tmp = self.dropout2(tmp)
x = self.norm2(x + tmp)
return x


class DecoderLayer(nn.Module):

def __init__(self,
heads: int,
d_model: int,
d_ff: int,
dropout: float = 0.1):
super().__init__()
self.self_attention = MultiHeadAttention(heads, d_model, dropout)
self.attention = MultiHeadAttention(heads, d_model, dropout)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)

def forward(self,
x,
encoder_kv: torch.Tensor,
dst_mask: Optional[torch.Tensor] = None,
src_dst_mask: Optional[torch.Tensor] = None):
tmp = self.self_attention(x, x, x, dst_mask)
tmp = self.dropout1(tmp)
x = self.norm1(x + tmp)
tmp = self.attention(x, encoder_kv, encoder_kv, src_dst_mask)
tmp = self.dropout2(tmp)
x = self.norm2(x + tmp)
tmp = self.ffn(x)
tmp = self.dropout3(tmp)
x = self.norm3(x + tmp)
return x

Encoder 和 Decoder 就是在所有子层前面加了一个嵌入层、一个位置编码,再把多个子层堆起来了而已,其他输入输出照搬即可。注意,我们可以给嵌入层输入pad_idx参数,让<pad>的计算不对梯度产生贡献。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class Encoder(nn.Module):

def __init__(self,
vocab_size: int,
pad_idx: int,
d_model: int,
d_ff: int,
n_layers: int,
heads: int,
dropout: float = 0.1,
max_seq_len: int = 120):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model, pad_idx)
self.pe = PositionalEncoding(d_model, max_seq_len)
self.layers = []
for i in range(n_layers):
self.layers.append(EncoderLayer(heads, d_model, d_ff, dropout))
self.layers = nn.ModuleList(self.layers)
self.dropout = nn.Dropout(dropout)

def forward(self, x, src_mask: Optional[torch.Tensor] = None):
x = self.embedding(x)
x = self.pe(x)
x = self.dropout(x)
for layer in self.layers:
x = layer(x, src_mask)
return x


class Decoder(nn.Module):

def __init__(self,
vocab_size: int,
pad_idx: int,
d_model: int,
d_ff: int,
n_layers: int,
heads: int,
dropout: float = 0.1,
max_seq_len: int = 120):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model, pad_idx)
self.pe = PositionalEncoding(d_model, max_seq_len)
self.layers = []
for i in range(n_layers):
self.layers.append(DecoderLayer(heads, d_model, d_ff, dropout))
self.layers = nn.Sequential(*self.layers)
self.dropout = nn.Dropout(dropout)

def forward(self,
x,
encoder_kv,
dst_mask: Optional[torch.Tensor] = None,
src_dst_mask: Optional[torch.Tensor] = None):
x = self.embedding(x)
x = self.pe(x)
x = self.dropout(x)
for layer in self.layers:
x = layer(x, encoder_kv, dst_mask, src_dst_mask)
return x

Transformer 类

终于,激动人心的时候到来了。我们要把各个子模块组成变形金刚(Transformer)了。先过一遍所有的代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Transformer(nn.Module):

def __init__(self,
src_vocab_size: int,
dst_vocab_size: int,
pad_idx: int,
d_model: int,
d_ff: int,
n_layers: int,
heads: int,
dropout: float = 0.1,
max_seq_len: int = 200):
super().__init__()
self.encoder = Encoder(src_vocab_size, pad_idx, d_model, d_ff,
n_layers, heads, dropout, max_seq_len)
self.decoder = Decoder(dst_vocab_size, pad_idx, d_model, d_ff,
n_layers, heads, dropout, max_seq_len)
self.pad_idx = pad_idx
self.output_layer = nn.Linear(d_model, dst_vocab_size)

def generate_mask(self,
q_pad: torch.Tensor,
k_pad: torch.Tensor,
with_left_mask: bool = False):
# q_pad shape: [n, q_len]
# k_pad shape: [n, k_len]
# q_pad k_pad dtype: bool
assert q_pad.device == k_pad.device
n, q_len = q_pad.shape
n, k_len = k_pad.shape

mask_shape = (n, 1, q_len, k_len)
if with_left_mask:
mask = 1 - torch.tril(torch.ones(mask_shape))
else:
mask = torch.zeros(mask_shape)
mask = mask.to(q_pad.device)
for i in range(n):
mask[i, :, q_pad[i], :] = 1
mask[i, :, :, k_pad[i]] = 1
mask = mask.to(torch.bool)
return mask

def forward(self, x, y):

src_pad_mask = x == self.pad_idx
dst_pad_mask = y == self.pad_idx
src_mask = self.generate_mask(src_pad_mask, src_pad_mask, False)
dst_mask = self.generate_mask(dst_pad_mask, dst_pad_mask, True)
src_dst_mask = self.generate_mask(dst_pad_mask, src_pad_mask, False)
encoder_kv = self.encoder(x, src_mask)
res = self.decoder(y, encoder_kv, dst_mask, src_dst_mask)
res = self.output_layer(res)
return res

我们一点一点来看。先看初始化函数。初始化函数的输入其实就是 Transformer 模型的超参数。总结一下,Transformer 应该有这些超参数:

  • d_model 模型中大多数词向量表示的维度大小
  • d_ff 前馈网络隐藏层维度大小
  • n_layers 堆叠的 Encoder & Decoder 层数
  • head 多头注意力的头数
  • dropout Dropout 的几率

另外,为了构建嵌入层,要知道源语言、目标语言的词典大小,并且提供pad_idx。为了预处理位置编码,需要提前知道一个最大序列长度。

照着子模块的初始化参数表,把参数归纳到__init__的参数表里即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def __init__(self,
src_vocab_size: int,
dst_vocab_size: int,
pad_idx: int,
d_model: int,
d_ff: int,
n_layers: int,
heads: int,
dropout: float = 0.1,
max_seq_len: int = 200):
super().__init__()
self.encoder = Encoder(src_vocab_size, pad_idx, d_model, d_ff,
n_layers, heads, dropout, max_seq_len)
self.decoder = Decoder(dst_vocab_size, pad_idx, d_model, d_ff,
n_layers, heads, dropout, max_seq_len)
self.pad_idx = pad_idx
self.output_layer = nn.Linear(d_model, dst_vocab_size)

再看一下 forward 函数。forward先预处理好了所有的 mask,再逐步执行 Transformer 的计算:先是通过 Encoder 获得源语言的中间表示encoder_kv,再把它和目标语言y的输入一起传入 Decoder,最后经过线性层输出结果res。由于 PyTorch 的交叉熵损失函数自带了 softmax 操作,这里不需要多此一举。

Transformer 论文提到,softmax 前的那个线性层可以和嵌入层共享权重。也就是说,嵌入和输出前的线性层分别完成了词序号到词嵌入的正反映射,两个操作应该是互逆的。但是,词嵌入矩阵不是一个方阵,它根本不能求逆矩阵。我想破头也没想清楚是怎么让线性层可以和嵌入层共享权重的。网上的所有实现都没有对这个细节多加介绍,只是新建了一个线性层。我也照做了。

1
2
3
4
5
6
7
8
9
10
11
def forward(self, x, y):

src_pad_mask = x == self.pad_idx
dst_pad_mask = y == self.pad_idx
src_mask = self.generate_mask(src_pad_mask, src_pad_mask, False)
dst_mask = self.generate_mask(dst_pad_mask, dst_pad_mask, True)
src_dst_mask = self.generate_mask(dst_pad_mask, src_pad_mask, False)
encoder_kv = self.encoder(x, src_mask)
res = self.decoder(y, encoder_kv, dst_mask, src_dst_mask)
res = self.output_layer(res)
return res

等了很久,现在可以来仔细看一看 mask 的生成方法了。回忆一下,表示该字符是否有效的 mask 有两个来源。第一个是论文里提到的,用于模拟串行推理的 mask;另一个是填充操作的空白字符引入的 mask。generate_mask 用于生成这些 mask。

generate_mask 的输入有 query 句子和 key 句子的 pad mask q_pad, k_pad,它们的形状为[n, seq_len]。若某处为 True,则表示这个地方的字符是<pad>。对于自注意力,query 和 key 都是一样的;而在 Decoder 的第二个多头注意力层中,query 来自目标语言,key 来自源语言。with_left_mask 表示是不是要加入 Decoder 里面的模拟串行推理的 mask,它会在掩码自注意力里用到。

1
2
3
4
def generate_mask(self,
q_pad: torch.Tensor,
k_pad: torch.Tensor,
with_left_mask: bool = False):

一开始,先取好维度信息,定好张量的形状。在注意力操作中,softmax 前的那个量的形状是 [n, heads, q_len, k_len],表示每一批每一个头的每一个query对每个key之间的相似度。每一个头的mask是一样的。因此,除heads维可以广播外,mask 的形状应和它一样。

1
mask_shape = (n, 1, q_len, k_len)

再新建一个表示最终 mask 的张量。如果不用 Decoder 的那种 mask,就生成一个全零的张量;否则,生成一个上三角为0,其余地方为1的张量。注意,在我的代码中,mask 为 True 或1就表示这个地方需要填负无穷。

1
2
3
4
if with_left_mask:
mask = 1 - torch.tril(torch.ones(mask_shape))
else:
mask = torch.zeros(mask_shape)

最后,把有 <pad> 的地方也标记一下。从mask的形状[n, 1, q_len, k_len]可以知道,q_pad 表示哪些行是无效的,k_pad 表示哪些列是无效的。如果query句子的第i个字符是<pad>,则应该令mask[:, :, i, :] = 1; 如果key句子的第j个字符是<pad>,则应该令mask[:, :, :, j] = 1

下面的代码利用了PyTorch的取下标机制,直接并行地完成了mask赋值。

1
2
3
for i in range(n):
mask[i, :, q_pad[i], :] = 1
mask[i, :, :, k_pad[i]] = 1

完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def generate_mask(self,
q_pad: torch.Tensor,
k_pad: torch.Tensor,
with_left_mask: bool = False):
# q_pad shape: [n, q_len]
# k_pad shape: [n, k_len]
# q_pad k_pad dtype: bool
assert q_pad.device == k_pad.device
n, q_len = q_pad.shape
n, k_len = k_pad.shape

mask_shape = (n, 1, q_len, k_len)
if with_left_mask:
mask = 1 - torch.tril(torch.ones(mask_shape))
else:
mask = torch.zeros(mask_shape)
mask = mask.to(q_pad.device)
for i in range(n):
mask[i, :, q_pad[i], :] = 1
mask[i, :, :, k_pad[i]] = 1
mask = mask.to(torch.bool)
return mask

看完了mask的生成方法后,我们回到前一步,看看mask会在哪些地方被调用。

在 Transformer 中,有三类多头注意力层,它们的 mask 也不同。Encoder 的多头注意力层的 query 和 key 都来自源语言;Decoder 的第一个多头注意力层的 query 和 key 都来自目标语言;Decoder 的第二个多头注意力层的 query 来自目标语言, key 来自源语言。另外,Decoder 的第一个多头注意力层要加串行推理的那个 mask。按照上述描述生成mask即可。

1
2
3
4
5
6
7
8
9
10
11
def forward(self, x, y):
src_pad_mask = x == self.pad_idx
dst_pad_mask = y == self.pad_idx
src_mask = self.generate_mask(src_pad_mask, src_pad_mask, False)
dst_mask = self.generate_mask(dst_pad_mask, dst_pad_mask, True)
src_dst_mask = self.generate_mask(dst_pad_mask, src_pad_mask, False)

encoder_kv = self.encoder(x, src_mask)
res = self.decoder(y, encoder_kv, dst_mask, src_dst_mask)
res = self.output_layer(res)
return res

到此,Transfomer 模型总算编写完成了。

这里再帮大家排一个坑。PyTorch的官方Transformer中使用了下面的参数初始化方式。但是,实际测试后,不知道为什么,我发现使用这种初始化会让模型训不起来。

1
2
3
4
def init_weights(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)

我去翻了翻PyTorch的Transformer示例,发现官方的示例根本没用到Transformer,而是用子模块nn.TransformerDecoder, nn.TransformerEncoder自己搭了一个新的Transformer。这些子模块其实都有自己的init_weights方法。看来官方都信不过自己的Transformer,这个Transformer类的初始化方法就有问题。

在我们的代码中,我们不必手动对参数初始化。PyTorch对每个线性层默认的参数初始化方式就够好了。

训练

准备好了模型、数据集后,剩下的工作非常惬意,只要随便调用一下就行了。训练的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import torch.nn as nn
import time

from dldemos.Transformer.data_load import (get_batch_indices, load_cn_vocab,
load_en_vocab, load_train_data,
maxlen)
from dldemos.Transformer.model import Transformer

# Config
batch_size = 64
lr = 0.0001
d_model = 512
d_ff = 2048
n_layers = 6
heads = 8
dropout_rate = 0.2
n_epochs = 60
PAD_ID = 0


def main():
device = 'cuda'
cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()
# X: en
# Y: cn
Y, X = load_train_data()

print_interval = 100

model = Transformer(len(en2idx), len(cn2idx), PAD_ID, d_model, d_ff,
n_layers, heads, dropout_rate, maxlen)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr)

citerion = nn.CrossEntropyLoss(ignore_index=PAD_ID)
tic = time.time()
cnter = 0
for epoch in range(n_epochs):
for index, _ in get_batch_indices(len(X), batch_size):
x_batch = torch.LongTensor(X[index]).to(device)
y_batch = torch.LongTensor(Y[index]).to(device)
y_input = y_batch[:, :-1]
y_label = y_batch[:, 1:]
y_hat = model(x_batch, y_input)

y_label_mask = y_label != PAD_ID
preds = torch.argmax(y_hat, -1)
correct = preds == y_label
acc = torch.sum(y_label_mask * correct) / torch.sum(y_label_mask)

n, seq_len = y_label.shape
y_hat = torch.reshape(y_hat, (n * seq_len, -1))
y_label = torch.reshape(y_label, (n * seq_len, ))
loss = citerion(y_hat, y_label)

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
optimizer.step()

if cnter % print_interval == 0:
toc = time.time()
interval = toc - tic
minutes = int(interval // 60)
seconds = int(interval % 60)
print(f'{cnter:08d} {minutes:02d}:{seconds:02d}'
f' loss: {loss.item()} acc: {acc.item()}')
cnter += 1

model_path = 'dldemos/Transformer/model.pth'
torch.save(model.state_dict(), model_path)

print(f'Model saved to {model_path}')


if __name__ == '__main__':
main()

所有的超参数都写在代码开头。在模型结构上,我使用了和原论文一样的超参数。

1
2
3
4
5
6
7
8
9
10
# Config
batch_size = 64
lr = 0.0001
d_model = 512
d_ff = 2048
n_layers = 6
heads = 8
dropout_rate = 0.2
n_epochs = 60
PAD_ID = 0

之后,进入主函数。一开始,我们调用load_data.py提供的API,获取中英文序号到单词的转换词典,并获取已经打包好的训练数据。

1
2
3
4
5
6
7
def main():
device = 'cuda'
cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()
# X: en
# Y: cn
Y, X = load_train_data()

接着,我们用参数初始化好要用到的对象,比如模型、优化器、损失函数。

1
2
3
4
5
6
7
8
9
10
11
print_interval = 100

model = Transformer(len(en2idx), len(cn2idx), PAD_ID, d_model, d_ff,
n_layers, heads, dropout_rate, maxlen)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr)

citerion = nn.CrossEntropyLoss(ignore_index=PAD_ID)
tic = time.time()
cnter = 0

再然后,进入训练循环。我们从X, Y里取出源语言和目标语言的序号数组,输入进模型里。别忘了,Transformer可以并行训练。我们给模型输入目标语言前n-1个单词,用第2到第n个单词作为监督标签。

1
2
3
4
5
6
7
for epoch in range(n_epochs):
for index, _ in get_batch_indices(len(X), batch_size):
x_batch = torch.LongTensor(X[index]).to(device)
y_batch = torch.LongTensor(Y[index]).to(device)
y_input = y_batch[:, :-1]
y_label = y_batch[:, 1:]
y_hat = model(x_batch, y_input)

得到模型的预测y_hat后,我们可以把输出概率分布中概率最大的那个单词作为模型给出的预测单词,算一个单词预测准确率。当然,我们要排除掉<pad>的影响。

1
2
3
4
y_label_mask = y_label != PAD_ID
preds = torch.argmax(y_hat, -1)
correct = preds == y_label
acc = torch.sum(y_label_mask * correct) / torch.sum(y_label_mask)

我们最后算一下loss,并执行梯度下降,训练代码就写完了。为了让训练更稳定,不出现梯度过大的情况,我们可以用torch.nn.utils.clip_grad_norm_(model.parameters(), 1)裁剪梯度。

1
2
3
4
5
6
7
8
9
n, seq_len = y_label.shape
y_hat = torch.reshape(y_hat, (n * seq_len, -1))
y_label = torch.reshape(y_label, (n * seq_len, ))
loss = citerion(y_hat, y_label)

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
optimizer.step()

实验

在本项目的实验中,使用单卡3090,约10分钟就能完成训练。最终的训练准确率可以到达90%以上。

1
00006300 12:12 loss: 0.43494755029678345 acc: 0.9049844145774841

该数据集没有提供测试集(原仓库里的测试集来自训练集,这显然不合理)。且由于词表太小,不太好构建测试集。因此,我没有编写从测试集里生成句子并算BLEU score的代码,而是写了一份翻译给定句子的代码。要编写测试BLUE score的代码,只需要把翻译任意句子的代码改个输入,加一个求BLEU score的函数即可。这份翻译任意句子的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch

from dldemos.Transformer.data_load import (load_cn_vocab, load_en_vocab,
idx_to_sentence, maxlen)
from dldemos.Transformer.model import Transformer

# Config
batch_size = 1
lr = 0.0001
d_model = 512
d_ff = 2048
n_layers = 6
heads = 8
dropout_rate = 0.2
n_epochs = 60

PAD_ID = 0


def main():
device = 'cuda'
cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()

model = Transformer(len(en2idx), len(cn2idx), 0, d_model, d_ff, n_layers,
heads, dropout_rate, maxlen)
model.to(device)
model.eval()

model_path = 'dldemos/Transformer/model.pth'
model.load_state_dict(torch.load(model_path))

my_input = ['we', "should", "protect", "environment"]
x_batch = torch.LongTensor([[en2idx[x] for x in my_input]]).to(device)

cn_sentence = idx_to_sentence(x_batch[0], idx2en, True)
print(cn_sentence)

y_input = torch.ones(batch_size, maxlen,
dtype=torch.long).to(device) * PAD_ID
y_input[0] = en2idx['<S>']
# y_input = y_batch
with torch.no_grad():
for i in range(1, y_input.shape[1]):
y_hat = model(x_batch, y_input)
for j in range(batch_size):
y_input[j, i] = torch.argmax(y_hat[j, i - 1])
output_sentence = idx_to_sentence(y_input[0], idx2cn, True)
print(output_sentence)


if __name__ == '__main__':
main()

一开始,还是先获取词表,并初始化模型。

1
2
3
4
5
6
7
8
9
10
11
12
def main():
device = 'cuda'
cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()

model = Transformer(len(en2idx), len(cn2idx), 0, d_model, d_ff, n_layers,
heads, dropout_rate, maxlen)
model.to(device)
model.eval()

model_path = 'dldemos/Transformer/model.pth'
model.load_state_dict(torch.load(model_path))

之后,我们用自己定义的句子(要做好分词)代替原来的输入x_batch。如果要测试某个数据集,只要把这里x_batch换成测试集里的数据即可。
我们可以顺便把序号数组用idx_to_sentence转回英文,看看序号转换有没有出错。

1
2
3
4
5
my_input = ['we', "should", "protect", "environment"]
x_batch = torch.LongTensor([[en2idx[x] for x in my_input]]).to(device)

cn_sentence = idx_to_sentence(x_batch[0], idx2en, True)
print(cn_sentence)

这段代码会输出we should protect environment。这说明x_batch是我们想要的序号数组。

最后,我们利用Transformer自回归地生成句子,并输出句子。

1
2
3
4
5
6
7
8
9
10
11
y_input = torch.ones(batch_size, maxlen,
dtype=torch.long).to(device) * PAD_ID
y_input[0] = en2idx['<S>']
# y_input = y_batch
with torch.no_grad():
for i in range(1, y_input.shape[1]):
y_hat = model(x_batch, y_input)
for j in range(batch_size):
y_input[j, i] = torch.argmax(y_hat[j, i - 1])
output_sentence = idx_to_sentence(y_input[0], idx2cn, True)
print(output_sentence)

要自回归地生成句子,我们先给句子填入无效字符<pad>,再把第一个字符换成句子开始字符<S>

1
2
3
y_input = torch.ones(batch_size, maxlen,
dtype=torch.long).to(device) * PAD_ID
y_input[0] = en2idx['<S>']

之后,我们循环调用Transformer,获取下一个单词的概率分布。我们可以认为,概率最大的那个单词就是模型预测的下一个单词。因此,我们可以用argmax获取预测的下一个单词的序号,填回y_input。这里的y_input和训练时那个y_batch是同一个东西。

1
2
3
4
5
6
# y_input = y_batch
with torch.no_grad():
for i in range(1, y_input.shape[1]):
y_hat = model(x_batch, y_input)
for j in range(batch_size):
y_input[j, i] = torch.argmax(y_hat[j, i - 1])

最后只要输出生成的句子即可。

1
2
output_sentence = idx_to_sentence(y_input[0], idx2cn, True)
print(output_sentence)

由于训练数据非常少,而且数据都来自新闻,我只好选择了一个比较常见的句子”we should protect environment”作为输入。模型翻译出了一个比较奇怪的结果。

1
<S> 要 保护 环境 保护 环境 保护 环境 保护 环境 保护 环境 保护 环境 保护 环境 的 生态 环境 落实 好 环境 </S> 环境 </S> 有效 保护 环境 </S>...

可以看出,模型确实学到了东西,能翻译出“要保护环境”。但是,这翻译的结果也太长太奇怪了。感觉是对训练数据过拟合了。当然,还是那句话,训练集里的数据太少。要提升模型性能并缓解过拟合,加数据集是最好的方法。这个结果起码说明我们Tranformer的编写没有问题。

在生成新句子的时候,我直接拿概率最高的单词当做预测的下一个单词。其实,还有一些更加高级的生成算法,比如Beam Search。如果模型训练得比较好,可以用这些高级一点的算法提高生成句子的质量。

我读了网上几份Transformer实现。这些实现在生成句子算BLEU score时,竟然直接输入测试句子的前n-1个单词,把输出的n-1个单词拼起来,作为模型的翻译结果。这个过程等价于告诉你前i个翻译答案,你去输出第i+1个单词,再把每个结果拼起来。这样写肯定是不合理的。正常来说应该是照着我这样自回归地生成翻译句子。大家参考网上的Transformer代码时要多加留心。

总结

只要读懂了 Transfomer 的论文,用 PyTorch 实现一遍 Transformer 是很轻松的。但是,代码实现中有非常多论文不会提及的细节,你自己实现时很容易踩坑。在这篇文章里,我完整地介绍了一个英中翻译 Transformer 的 PyTorch 实现,相信读者能够跟随这篇文章实现自己的 Transformer,并在代码实现的过程中加深对论文的理解。

再稍微总结一下代码实现中的一些值得注意的地方。代码中最大的难点是 mask 的实现。mask 的处理稍有闪失,就可能会让计算结果中遍布 NaN。一定要想清楚各个模块的 mask 是从哪来的,它们在注意力计算里是怎么被用上的。

另外,有两处地方的实现比较灵活。一处是位置编码的实现,一处是多头注意力中怎么描述“多头”。其他模块的实现都大差不差,千篇一律。

最后再提醒一句,要从头训练一个模型,一定要从小数据集上开始做。不然你训练个半天,结果差了,你不知道是数据有问题,还是代码有问题。我之前一直在使用很大的训练集,每次调试都非常麻烦,浪费了很多时间。希望大家引以为戒。

参考资料

感谢 https://github.com/P3n9W31/transformer-pytorch 提供的数据集。

一份简明的Transformer实现代码 https://github.com/hyunwoongko/transformer

一篇不错的Transformer实现教程 https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec

过期内容

我第一次写这篇文章时过于仓促,文章中有不少错误,实验部分也没写完。我后来把本文又重新修改了一遍,补充了实验部分。

我之前使用了一个较大的数据集,但发现做实验做得很慢,于是换了一个较小的数据集。以前的数据集预处理介绍就挪到这里了。

数据集与评测方法

在开启一个深度学习项目之初,要把任务定义好。准确来说,我们要明白这个任务是在完成一个怎样的映射,并准备一个用于评测的数据集,定义好评价指标。

英中翻译,这个任务非常明确,就是把英文的句子翻译成中文。英中翻译的数据集应该包含若干个句子对,每个句子对由一句英文和它对应的中文翻译组成。

中英翻译的数据集不是很好找。有几个比较出名的数据集的链接已经失效了,还有些数据集需要注册与申请后才能获取。我在中文NLP语料库仓库(https://github.com/brightmart/nlp_chinese_corpus)找到了中英文平行语料 translation2019zh。该语料库由520万对中英文语料构成,训练集516万对,验证集3.9万对。用作训练和验证中英翻译模型是足够了。

机器翻译的评测指标叫做BLEU Score。如果模型输出的翻译和参考译文有越多相同的单词、连续2个相同单词、连续3个相同单词……,则得分越高。

PyTorch 提供了便捷的API,我们可以用一行代码算完BLEU Score。

1
2
3
4
5
>>> from torchtext.data.metrics import bleu_score
>>> candidate_corpus = [['My', 'full', 'pytorch', 'test'], ['Another', 'Sentence']]
>>> references_corpus = [[['My', 'full', 'pytorch', 'test'], ['Completely', 'Different']], [['No', 'Match']]]
>>> bleu_score(candidate_corpus, references_corpus)
0.8408964276313782

数据清洗

得到数据集后,下一步要做的是对数据集做处理,把原始数据转化成能够输入神经网络的张量。对于图片,预处理可能是裁剪、缩放,使所有图片都有一样的大小;对于文本,预处理可能是分词、填充。

网盘上下载好 translation2019zh 数据集后,我们来一步一步清洗这个数据集。这个数据集只有两个文件translation2019zh_train.json, translation2019zh_valid.json,它们的结构如下:

text
1
2
3
4
{"english": <english>, "chinese": <chinese>}
{"english": <english>, "chinese": <chinese>}
{"english": <english>, "chinese": <chinese>}
...

这些json文件有点不合标准,每对句子由一行json格式的记录组成。english属性是英文句子,chinese属性是中文句子。比如:

text
1
{"english": "In Italy ...", "chinese": "在意大利 ..."}

因此,在读取数据时,我们可以用下面的代码提取每对句子。

1
2
3
4
5
6
import json

with open(json_path, 'r') as fp:
for line in fp:
line = json.loads(line)
english, chinese = line['english'], line['chinese']

这个数据集有一点不干净,有一些句子对的中英文句子颠倒过来了。为此,我们要稍微处理一下,把这些句子对翻转过来。如果一个英文句子不全由 ASCII 组成,则它可能是一个被标错的中文句子。

1
2
3
# Correct mislabeled data
if not english.isascii():
english, chinese = chinese, english

经过这一步,我们只得到了中英文的字符文本。而在NLP中,大部分处理的最小单位都是符号(token)——对于英文来说,符号是单词、标点;对于中文来说,符号是词语、标点。我们还需要一个符号化的过程。

英文符号化非常方便,torchtext 提供了非常便捷的英文分词 API。

1
2
3
4
from torchtext.data import get_tokenizer

tokenizer = get_tokenizer('basic_english')
english = tokenizer(english)

而中文分词方面,我使用了jieba库。该库可以直接 pip 安装。

1
pip install jieba

分词的 API 是 jieba.cut。由于分词的结果中,相邻的词之间有空格,我一股脑地把所有空白符给过滤掉了。

1
2
3
import jieba
chinese = list(jieba.cut(chinese))
chinese = [x for x in chinese if x not in {' ', '\t'}]

经过这些处理后,每句话被转换成了中文词语或英文单词的数组。整个处理代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def read_file(json_path):
english_sentences = []
chinese_sentences = []
tokenizer = get_tokenizer('basic_english')
with open(json_path, 'r') as fp:
for line in fp:
line = json.loads(line)
english, chinese = line['english'], line['chinese']
# Correct mislabeled data
if not english.isascii():
english, chinese = chinese, english
# Tokenize
english = tokenizer(english)
chinese = list(jieba.cut(chinese))
chinese = [x for x in chinese if x not in {' ', '\t'}]
english_sentences.append(english)
chinese_sentences.append(chinese)
return english_sentences, chinese_sentences

词语转序号

为了让计算机更方便地处理单词,我们还要把单词转换成序号。比如令apple为0号,banana为1号,则句子apple banana apple就转换成了0 1 0

给每一个单词选一个标号,其实就是要建立一个词典。一般来说,我们可以利用他人的统计结果,挑选最常用的一些英文单词和中文词语构成词典。不过,现在我们已经有了一个庞大的中英语料库了,我们可以直接从这个语料库中挑选出最常见的词构成词典。

根据上一步处理得到的句子数组sentences,我们可以用下面的 Python 代码统计出最常见的一些词语,把它们和4个特殊字符<sos>, <eos>, <unk>, <pad>(句子开始字符、句子结束字符、频率太少没有被加入词典的词语、填充字符)一起构成词典。统计字符出现次数是通过 Python 的 Counter 类实现的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from collections import Counter

def create_vocab(sentences, max_element=None):
"""Note that max_element includes special characters"""

default_list = ['<sos>', '<eos>', '<unk>', '<pad>']

char_set = Counter()
for sentence in sentences:
c_set = Counter(sentence)
char_set.update(c_set)

if max_element is None:
return default_list + list(char_set.keys())
else:
max_element -= 4
words_freq = char_set.most_common(max_element)
# pair array to double array
words, freq = zip(*words_freq)
return default_list + list(words)

准备好了词典后,我还编写了两个工具函数sentence_to_tensortensor_to_sentence,它们可以用于字符串数组与序号数组的互相转换。测试这些代码的脚本及其输出如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Dataset.py

def main():
en_sens, zh_sens = read_file(
'data/translation2019zh/translation2019zh_valid.json')
print(*en_sens[0:3])
print(*zh_sens[0:3])
en_vocab = create_vocab(en_sens, 10000)
zh_vocab = create_vocab(zh_sens, 30000)
print(list(en_vocab)[0:10])
print(list(zh_vocab)[0:10])

en_tensors = sentence_to_tensor(en_sens, en_vocab)
zh_tensors = sentence_to_tensor(zh_sens, zh_vocab)

print(tensor_to_sentence(en_tensors[0], en_vocab, True))
print(tensor_to_sentence(zh_tensors[0], zh_vocab))
text
1
2
3
4
5
6
['slowly', 'and', 'not', 'without', 'struggle', ',', 'america', 'began', 'to', 'listen', '.'] ...]
['美国', '缓慢', '地', '开始', '倾听', ',', '但', '并非', '没有', '艰难曲折', '。'] ...]
['<sos>', '<eos>', '<unk>', '<pad>', 'the', '.', ',', 'of', 'and', 'to']
['<sos>', '<eos>', '<unk>', '<pad>', '的', ',', '。', '在', '了', '和']
slowly and not without struggle , america began to listen .
美国缓慢地开始倾听,但并非没有<unk>。

在这一步中,有一个重要的参数:词典的大小。显然,词典越大,能处理的词语越多,但训练速度也会越慢。由于这个项目只是一个用于学习的demo,我设置了比较小的词典大小。想提升整个模型的性能的话,调大词典大小是一个最快的方法。

生成 Dataloader

都说程序员是新时代的农民工,这非常有道理。因为,作为程序员,你免不了要写一些繁重、无聊的数据处理脚本。还好,写完这些无聊的预处理代码后,总算可以使用 PyTorch 的 API 写一些有趣的代码了。

把词语数组转换成序号句子数组后,我们要考虑怎么把序号句子数组输入给模型了。文本数据通常长短不一,为了一次性处理一个 batch 的数据,要把短的句子填充,使得一批句子长度相等。写 Dataloader 时最主要的工作就是填充并对齐句子。

先看一下Dataset的写法。上一步得到的序号句子数组可以塞进Dataset里。注意,每个句子的前后要加上表示句子开始和结束的特殊符号。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
SOS_ID = 0
EOS_ID = 1
UNK_ID = 2
PAD_ID = 3

class TranslationDataset(Dataset):

def __init__(self, en_tensor: np.ndarray, zh_tensor: np.ndarray):
super().__init__()
assert len(en_tensor) == len(zh_tensor)
self.length = len(en_tensor)
self.en_tensor = en_tensor
self.zh_tensor = zh_tensor

def __len__(self):
return self.length

def __getitem__(self, index):
x = np.concatenate(([SOS_ID], self.en_tensor[index], [EOS_ID]))
x = torch.from_numpy(x)
y = np.concatenate(([SOS_ID], self.zh_tensor[index], [EOS_ID]))
y = torch.from_numpy(y)
return x, y

接下来看一下 DataLoader 的写法。在创建 Dataloader 时,最重要的是 collate_fn 的编写,这个函数决定了怎么把多条数据合成一个等长的 batch。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def get_dataloader(en_tensor: np.ndarray,
zh_tensor: np.ndarray,
batch_size=16):

def collate_fn(batch):
...

dataset = TranslationDataset(en_tensor, zh_tensor)
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collate_fn)

return dataloader

collate_fn 的输入是多个 dataset __getitem__ 的返回结果构成的数组。对于我们的 dataset 来说,collate_fn 的输入是 [(x1, y1), (x2, y2), ...] 。我们可以用 zip(*batch) 把二元组数组拆成两个数组 x, y

collate_fn 的输出就是将来 dataloader 的输出。PyTorch 提供了 pad_sequence 函数用来把一批数据填充至等长。

1
2
3
4
5
6
7
8
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
x, y = zip(*batch)
x_pad = pad_sequence(x, batch_first=True, padding_value=PAD_ID)
y_pad = pad_sequence(y, batch_first=True, padding_value=PAD_ID)

return x_pad, y_pad

实现完collate_fn后,我们就可以得到了DataLoader。这样,数据集预处理部分大功告成。

近两年,有许多图像生成类任务的前沿工作都使用了一种叫做”codebook”的机制。追溯起来,codebook机制最早是在VQ-VAE论文中提出的。相比于普通的VAE,VQ-VAE能利用codebook机制把图像编码成离散向量,为图像生成类任务提供了一种新的思路。VQ-VAE的这种建模方法启发了无数的后续工作,包括声名远扬的Stable Diffusion。

在这篇文章中,我将先以易懂的逻辑带领大家一步一步领悟VQ-VAE的核心思想,再介绍VQ-VAE中关键算法的具体形式,最后把VQ-VAE的贡献及其对其他工作的影响做一个总结。通过阅读这篇文章,你不仅能理解VQ-VAE本身的原理,更能知道如何将VQ-VAE中的核心机制活学活用。

从 AE 到 VQ-VAE

为什么VQ-VAE想要把图像编码成离散向量?让我们从最早的自编码器(Autoencoder, AE)开始一步一步谈起。AE是一类能够把图片压缩成较短的向量的神经网络模型,其结构如下图所示。AE包含一个编码器$e()$和一个解码器$d()$。在训练时,输入图像$\mathbf{x}$会被编码成一个较短的向量$\mathbf{z}$,再被解码回另一幅长得差不多的图像$\hat{\mathbf{x}}$。网络的学习目标是让重建出来的图像$\hat{\mathbf{x}}$和原图像$\mathbf{x}$尽可能相似。

解码器可以把一个向量解码成图片。换一个角度看,解码器就是一个图像生成模型,因为它可以根据向量来生成图片。那么,AE可不可以用来做图像生成呢?很可惜,AE的编码器编码出来的向量空间是不规整的。也就是说,解码器只认识经编码器编出来的向量,而不认识其他的向量。如果你把自己随机生成出来的向量输入给解码器,解码器是生成不出有意义的图片的。AE不能够随机生成图片,所以它不能很好地完成图像生成任务,只能起到把图像压缩的作用。

AE离图像生成只差一步了。只要AE的编码空间比较规整,符合某个简单的数学分布(比如最常见的标准正态分布),那我们就可以从这个分布里随机采样向量,再让解码器根据这个向量来完成随机图片生成了。VAE就是这样一种改进版的AE。它用一些巧妙的方法约束了编码向量$\mathbf{z}$,使得$\mathbf{z}$满足标准正态分布。这样,解码器不仅认识编码器编出的向量,还认识其他来自标准正态分布的向量。训练完成后,我们就可以扔掉编码器,用来自标准正态分布的随机向量和解码器来实现随机图像生成了。

VAE的实现细节就不在这里赘述了,是否理解它对理解VQ-VAE没有影响。我们只需知道VAE可以把图片编码成符合标准正态分布的向量即可。让向量符合标准正态分布的原因是方便随机采样。同时,需要强调的是,VAE编码出来的向量是连续向量,也就是向量的每一维都是浮点数。如果把向量的某一维稍微改动0.0001,解码器还是认得这个向量,并且会生成一张和原向量对应图片差不多的图片。

但是,VAE生成出来的图片都不是很好看。VQ-VAE的作者认为,VAE的生成图片之所以质量不高,是因为图片被编码成了连续向量。而实际上,把图片编码成离散向量会更加自然。比如我们想让画家画一个人,我们会说这个是男是女,年龄是偏老还是偏年轻,体型是胖还是壮,而不会说这个人性别是0.5,年龄是0.6,体型是0.7。因此,VQ-VAE会把图片编码成离散向量,如下图所示。

把图像编码成离散向量后,又会带来两个新的问题。第一个问题是,神经网络会默认输入满足一个连续的分布,而不善于处理离散的输入。如果你直接输入0, 1, 2这些数字,神经网络会默认1是一个处于0, 2中间的一种状态。为了解决这一问题,我们可以借鉴NLP中对于离散单词的处理方法。为了处理离散的输入单词,NLP模型的第一层一般都是词嵌入层,它可以把每个输入单词都映射到一个独一无二的连续向量上。这样,每个离散的数字都变成了一个特别的连续向量了。

我们可以把类似的嵌入层加到VQ-VAE的解码器前。这个嵌入层在VQ-VAE里叫做”embedding space(嵌入空间)”,在后续文章中则被称作”codebook”。

离散向量的另一个问题是它不好采样。回忆一下,VAE之所以把图片编码成符合正态分布的连续向量,就是为了能在图像生成时把编码器扔掉,让随机采样出的向量也能通过解码器变成图片。现在倒好,VQ-VAE把图片编码了一个离散向量,这个离散向量构成的空间是不好采样的。VQ-VAE不是面临着和AE一样的问题嘛。

这个问题是无解的。没错!VQ-VAE根本不是一个图像生成模型。它和AE一样,只能很好地完成图像压缩,把图像变成一个短得多的向量,而不支持随机图像生成。VQ-VAE和AE的唯一区别,就是VQ-VAE会编码出离散向量,而AE会编码出连续向量。

可为什么VQ-VAE会被归类到图像生成模型中呢?这是因为VQ-VAE的作者利用VQ-VAE能编码离散向量的特性,使用了一种特别的方法对VQ-VAE的离散编码空间采样。VQ-VAE的作者之前设计了一种图像生成网络,叫做PixelCNN。PixelCNN能拟合一个离散的分布。比如对于图像,PixelCNN能输出某个像素的某个颜色通道取0~255中某个值的概率分布。这不刚好嘛,VQ-VAE也是把图像编码成离散向量。换个更好理解的说法,VQ-VAE能把图像映射成一个「小图像」。我们可以把PixelCNN生成图像的方法搬过来,让PixelCNN学习生成「小图像」。这样,我们就可以用PixelCNN生成离散编码,再利用VQ-VAE的解码器把离散编码变成图像。

让我们来整理一下VQ-VAE的工作过程。

  1. 训练VQ-VAE的编码器和解码器,使得VQ-VAE能把图像变成「小图像」,也能把「小图像」变回图像。
  2. 训练PixelCNN,让它学习怎么生成「小图像」。
  3. 随机采样时,先用PixelCNN采样出「小图像」,再用VQ-VAE把「小图像」翻译成最终的生成图像。

到这里,我们已经学完了VQ-VAE的核心思想。让我们来总结一下。VQ-VAE不是一个VAE,而是一个AE。它的目的是把图像压缩成离散向量。或者换个角度说,它提供了把大图像翻译成「小图像」的方法,也提供了把「小图像」翻译成大图像的方法。这样,一个随机生成大图像的问题,就被转换成了一个等价的随机生成一个较小的「图像」的问题。有一些图像生成模型,比如PixelCNN,更适合拟合离散分布。可以用它们来完成生成「小图像」的问题,填补上VQ-VAE生成图片的最后一片空缺。

VQ-VAE 设计细节

在上一节中,我们虽然认识了VQ-VAE的核心思想,但略过了不少实现细节,比如:

  • VQ-VAE的编码器怎么输出离散向量。
  • VQ-VAE怎么优化编码器和解码器。
  • VQ-VAE怎么优化嵌入空间。

在这一节里,我们来详细探究这些细节。

输出离散编码

想让神经网络输出一个整数,最简单的方法是和多分类模型一样,输出一个Softmax过的概率分布。之后,从概率分布里随机采样一个类别,这个类别的序号就是我们想要的整数。比如在下图中,我们想得到一个由3个整数构成的离散编码,就应该让编码器输出3组logit,再经过Softmax与采样,得到3个整数。

但是,这么做不是最高效的。得到离散编码后,下一步我们又要根据嵌入空间把离散编码转回一个向量。可见,获取离散编码这一步有一点多余。能不能把编码器的输出张量(它之前的名字叫logit)、解码器的输入张量embedding、嵌入空间直接关联起来呢?

VQ-VAE使用了如下方式关联编码器的输出与解码器的输入:假设嵌入空间已经训练完毕,对于编码器的每个输出向量$z_e(x)$,找出它在嵌入空间里的最近邻$z_q(x)$,把$z_e(x)$替换成$z_q(x)$作为解码器的输入。

求最近邻,即先计算向量与嵌入空间$K$个向量每个向量的距离,再对距离数组取一个argmin,求出最近的下标(比如图中的0, 1, 1),最后用下标去嵌入空间里取向量。下标构成的数组(比如图中的[0, 1, 1])也正是VQ-VAE的离散编码。

就这样,我们知道了VQ-VAE是怎么生成离散编码的。VQ-VAE的编码器其实不会显式地输出离散编码,而是输出了多个「假嵌入」$z_e(x)$。之后,VQ-VAE对每个$z_e(x)$在嵌入空间里找最近邻,得到真正的嵌入$z_q(x)$,把$z_q(x)$作为解码器的输入。

虽然我们现在能把编码器和解码器拼接到一起,但现在又多出了一个问题:怎么让梯度从解码器的输入$z_q(x)$传到$z_e(x)$?从$z_e(x)$到$z_q(x)$的变换是一个从数组里取值的操作,这个操作是求不了导的。我们在下一小节里来详细探究一下怎么优化VQ-VAE的编码器和解码器。

优化编码器和解码器

为了优化编码器和解码器,我们先来制订一下VQ-VAE的整体优化目标。由于VQ-VAE其实是一个AE,误差函数里应该只有原图像和目标图像的重建误差。

或者非要从VAE的角度说也行。VQ-VAE相当于输出了一个one-hot离散分布。假设输入图像$x$的离散编码$z$是$k$,则分布中仅有$q(z=k|x)=1$,$q(z=others|x)=0$。令离散编码$z$的先验分布是均匀分布(假设不知道输入图像$x$,每个离散编码取到的概率是等同的),则先验分布$q(z)$和后验分布$q(z|x)$的KL散度是常量。因此,KL散度项不用算入损失函数里。理解此处的数学推导意义不大,还不如直接理解成VQ-VAE其实是一个AE。

但直接拿这个误差来训练是不行的。误差中,$z_q(x)$是解码器的输入。从编码器输出$z_e(x)$到$z_q(x)$这一步是不可导的,误差无法从解码器传递到编码器上。要是可以把$z_q(x)$的梯度直接原封不动地复制到$z_e(x)$上就好了。

VQ-VAE使用了一种叫做”straight-through estimator”的技术来完成梯度复制。这种技术是说,前向传播和反向传播的计算可以不对应。你可以为一个运算随意设计求梯度的方法。基于这一技术,VQ-VAE使用了一种叫做$sg$(stop gradient,停止梯度)的运算:

也就是说,前向传播时,$sg$里的值不变;反向传播时,$sg$按值为0求导,即此次计算无梯度。(反向传播其实不会用到式子的值,只会用到式子的梯度。反向传播用到的loss值是在前向传播中算的)。

基于这种运算,我们可以设计一个把梯度从$z_e(x)$复制到$z_q(x)$的误差:

也就是说,前向传播时,就是拿解码器输入$z_q(x)$来算梯度。

而反向传播时,按下面这个公式求梯度,等价于把解码器的梯度全部传给$z_e(x)$。

这部分的PyTorch实现如下所示。在PyTorch里,(x).detach()就是$sg(x)$,它的值在前向传播时取x,反向传播时取0

1
L = x - decoder(z_e + (z_q - z_e).detach())

通过这一技巧,我们完成了梯度的传递,可以正常地训练编码器和解码器了。

优化嵌入空间

到目前为止,我们的讨论都是建立在嵌入空间已经训练完毕的前提上的。现在,我们来讨论一下嵌入空间的训练方法。

嵌入空间的优化目标是什么呢?嵌入空间的每一个向量应该能概括一类编码器输出的向量,比如一个表示「青年」的向量应该能概括所有14-35岁的人的照片的编码器输出。因此,嵌入空间的向量应该和其对应编码器输出尽可能接近。如下面的公式所示,$z_e(x)$是编码器的输出向量,$z_q(x)$是其在嵌入空间的最近邻向量。

但作者认为,编码器和嵌入向量的学习速度应该不一样快。于是,他们再次使用了停止梯度的技巧,把上面那个误差函数拆成了两部分。其中,$\beta$控制了编码器的相对学习速度。作者发现,算法对$\beta$的变化不敏感,$\beta$取0.1~2.0都差不多。

其实,在论文中,作者分别讨论了上面公式里的两个误差。第一个误差来自字典学习算法里的经典算法Vector Quantisation(VQ),也就是VQ-VAE里的那个VQ,它用于优化嵌入空间。第二个误差叫做专注误差,它用于约束编码器的输出,不让它跑到离嵌入空间里的向量太远的地方。

这样,VQ-VAE总体的损失函数可以写成:(由于算上了重建误差,我们多加一个$\alpha$用于控制不同误差之间的比例)

总结

VQ-VAE是一个把图像编码成离散向量的图像压缩模型。为了让神经网络理解离散编码,VQ-VAE借鉴了NLP的思想,让每个离散编码值对应一个嵌入,所有的嵌入都存储在一个嵌入空间(又称”codebook”)里。这样,VQ-VAE编码器的输出是若干个「假嵌入」,「假嵌入」会被替换成嵌入空间里最近的真嵌入,输入进解码器里。

VQ-VAE的优化目标由两部分组成:重建误差和嵌入空间误差。重建误差为输入图片和重建图片的均方误差。为了让梯度从解码器传到编码器,作者使用了一种巧妙的停止梯度算子,让正向传播和反向传播按照不同的方式计算。嵌入空间误差为嵌入和其对应的编码器输出的均方误差。为了让嵌入和编码器以不同的速度优化,作者再次使用了停止梯度算子,把嵌入的更新和编码器的更新分开计算。

训练完成后,为了实现随机图像生成,需要对VQ-VAE的离散分布采样,再把采样出来的离散向量对应的嵌入输入进解码器。VQ-VAE论文使用了PixelCNN来采样离散分布。实际上,PixelCNN不是唯一一种可用的拟合离散分布的模型。我们可以把它换成Transformer,甚至是diffusion模型。如果你当年看完VQ-VAE后立刻把PixelCNN换成了diffusion模型,那么恭喜你,你差不多提前设计出了Stable Diffusion。

可见,VQ-VAE最大的贡献是提供了一种图像压缩思路,把生成大图像的问题转换成了一个更简单的生成「小图像」的问题。图像压缩成离散向量时主要借助了嵌入空间,或者说”codebook”这一工具。这种解决问题的思路可以应用到所有图像生成类任务上,比如超分辨率、图像修复、图像去模糊等。所以近两年我们能看到很多使用了codebook的图像生成类工作。

参考资料

PixelCNN的介绍可以参见我之前的文章:详解PixelCNN大家族。

VQ-VAE的论文为Neural Discrete Representation Learning。这篇文章不是很好读懂,建议直接读我的这篇解读。再推荐另一份还不错的中文解读 https://www.spaces.ac.cn/archives/6760。

图像生成是一个较难建模的任务。为此,我们要用GAN、VAE、Diffusion等精巧的架构来建模图像生成。可是,在NLP中,文本生成却有一种非常简单的实现方法。NLP中有一种基础的概率模型——N元语言模型。N元语言模型可以根据句子的前几个字预测出下一个字的出现概率。比如看到「我爱吃苹……」这句话的前几个字,我们不难猜出下一个字大概率是「果」字。利用N元语言模型,我们可以轻松地实现一个文本生成算法:输入空句子,采样出第一个字;输入第一个字,采样出第二个字;输入前两个字,输出第三个字……以此类推。

既然如此,我们可不可以把相同的方法搬到图像生成里呢?当然可以。虽然图像是二维的数据,不像一维的文本一样有先后顺序,但是我们可以强行给图像的每个像素规定一个顺序。比如,我们可以从左到右,从上到下地给图像标上序号。这样,从逻辑上看,图像也是一个一维数据,可以用NLP中的方法来按照序号实现图像生成了。

PixelCNN就是一个使用这种方法生成图像的模型。可为什么PixelCNN的名气没有GAN、VAE那么大?为什么PixelCNN可以用CNN而不是RNN来处理一维化图像?为什么PixelCNN是一种「自回归模型」?别急,在这篇文章中,我们将认识PixelCNN及其改进模型Gated PixelCNN和PixelCNN++,并认真学习它们的实现代码。看完文章后,这些问题都会迎刃而解。

PixelCNN

如前所述,PixelCNN借用了NLP里的方法来生成图像。模型会根据前i - 1个像素输出第i个像素的概率分布。训练时,和多分类任务一样,要根据第i个像素的真值和预测的概率分布求交叉熵损失函数;采样时,直接从预测的概率分布里采样出第i个像素。根据这些线索,我们来尝试自己「发明」一遍PixelCNN。

这种模型最朴素的实现方法,是输入一幅图像的前i - 1个像素,输出第i个像素的概率分布,即第i个像素取某种颜色的概率的数组。为了方便讨论,我们先只考虑单通道图像,每个像素的颜色取值只有256种。因此,准确来说,模型的输出是256个经过softmax的概率。这样,我们得到了一个V1.0版本的模型。

等等,模型不是叫「PixelCNN」吗?CNN跑哪去了?的确,对于图像数据,最好还是使用CNN,快捷又有效。因此,我们应该修改模型,令模型的输入为整幅图像和序号i。我们根据序号i,过滤掉ii之后的像素,用CNN处理图像。输出部分还是保持一致。

V2.0并不是最终版本,我们可以暂时不用考虑实现细节,比如这里的「过滤」是怎么实现的。硬要做的话,这种过滤也可以暴力实现:把无效像素初始化为0,每次卷积后再把无效像素置0。

改进之后,V2.0版本的模型确实能快速计算第i个像素的概率分布了。可是,CNN是很擅长同时生成一个和原图像长宽相同的张量的,只算一个像素的概率分布还称不上高效。所以,我们可以让模型输入一幅图像,同时输出图像每一处的概率分布。

这次的改进并不能加速采样。但是,在训练时,由于整幅训练图像已知,我们可以在一次前向传播后得到图像每一处的概率分布。假设图像有N个像素,我们就等于是在并行地训练N个样本,训练速度快了N倍!

这种并行训练的想法和Transformer如出一辙。

V3.0版本的PixelCNN已经和论文里的PixelCNN非常接近了,我们来探讨一下网络的实现细节。相比普通的CNN,PixelCNN有一个特别的约束:第i个像素只能看到前i-1个像素的信息,不能看到第i个像素及后续像素的信息。对于V2.0版本只要输出一个概率分布的PixelCNN,我们可以通过一些简单处理过滤掉第i个像素之后的信息。而对于并行输出所有概率分布的V3.0版本,让每个像素都忽略后续像素的信息的方法就不是那么显然了。

PixelCNN论文里提出了一种掩码卷积机制,这种机制可以巧妙地掩盖住每个像素右侧和下侧的信息。具体来说,PixelCNN使用了两类掩码卷积,我们把两类掩码卷积分别称为「A类」和「B类」。二者都是对卷积操作的卷积核做了掩码处理,使得卷积核的右下部分不产生贡献。A类和B类的唯一区别在于卷积核的中心像素是否产生贡献。CNN的第一个的卷积层使用A类掩码卷积,之后每一层的都使用B类掩码卷积。如下图所示。

为什么要先用一次A类掩码卷积,再每次使用B类掩码卷积呢?我们不妨来做一个实验。对于一个7x7的图像,我们先用1次3x3 A类掩码卷积,再用若干次3x3 B类掩码卷积。我们观察图像中心处的像素在每次卷积后的感受野(即输入图像中哪些像素的信息能够传递到中心像素上)。

不难看出,经过了第一个A类掩码卷积后,每个像素就已经看不到自己位置上的输入信息了。再经过两次B类卷积,中心像素能够看到左上角大部分像素的信息。这满足PixelCNN的约束。

而如果一直使用A类卷积,每次卷积后中心像素都会看漏一些信息(不妨对比下面这张示意图和上面那张示意图)。多卷几层后,中心像素的值就会和输入图像毫无关系。

只是用B类卷积也是不行的。显然,如果第一层就使用B类卷积,中心像素还是能看到自己位置的输入信息。这打破了PixelCNN的约束。这下,我们能明白为什么只能先用一次A类卷积,再用若干次B类卷积了。

利用两类掩码卷积,PixelCNN满足了每个像素只能接受之前像素的信息这一约束。除此之外,PixelCNN就没有什么特别的地方了。我们可以用任意一种CNN架构来实现PixelCNN。PixelCNN论文使用了一种类似于ResNet的架构。其中,第一个7x7卷积层用了A类掩码卷积,之后所有3x3卷积都是B类掩码卷积。

到目前为止,我们已经成功搭建了处理单通道图像的PixelCNN。现在,我们来尝试把它推广到多通道图像上。相比于单通道图像,多通道图像只不过是一个像素由多个颜色分量组成。我们可以把一个像素的颜色分量看成是子像素。在定义约束关系时,我们规定一个子像素只由它之前的子像素决定。比如对于RGB图像,R子像素由它之前所有像素决定,G子像素由它的R子像素和之前所有像素决定,B子像素由它的R、G子像素和它之前所有像素决定。生成图像时,我们一个子像素一个子像素地生成。

把我们的PixelCNN V3.0推广到RGB图像时,我们要做的第一件事就是修改网络的通道数量。由于现在要预测三个颜色通道,网络的输出应该是一个[256x3, H, W]形状的张量,即每个像素输出三个概率分布,分别表示R、G、B取某种颜色的概率。同时,本质上来讲,网络是在并行地为每个像素计算3组结果。因此,为了达到同样的性能,网络所有的特征图的通道数也要乘3。

这里说网络中间的通道数要乘3只是一种方便理解的说法。实际上,中间的通道数可以随意设置,是不是3的倍数都无所谓,只是所有通道在逻辑上被分成了3组。我们稍后会利用到「中间结果的通道数应该能被拆成3组」这一性质。

图像变为多通道后,A类卷积和B类卷积的定义也需要做出一些调整。我们不仅要考虑像素在空间上的约束,还要考虑一个像素内子像素间的约束。为此,我们要用不同的策略实现约束。为了方便描述,我们设卷积核组的形状为[o, i, h, w],其中o为输出通道数,i为输入通道数,h, w为卷积核的高和宽。

  1. 对于通道间的约束,我们要在o, i两个维度上设置掩码。设输出通道可以被拆成三组o1, o2, o3,输入通道可以被拆成三组i1, i2, i3,即o1 = 0:o/3, o2 = o/3:o*2/3, o3 = o*2/3:oi1 = 0:i/3, i2 = i/3:i*2/3, i3 = i*2/3:i。序号1, 2, 3分别表示这组通道是在维护R, G, B的计算。我们对输入通道组和输出通道组之间进行约束。对于A类卷积,我们令o1看不到i1, i2, i3o2看不到i2, i3o3看不到i3;对于B类卷积,我们取消每个通道看不到自己的限制,即在A类卷积的基础上令o1看到i1o2看到i2o3看到i3

  2. 对于空间上的约束,我们还是和之前一样,在h, w两个维度上设置掩码。由于「是否看到自己」的处理已经在o, i两个维度里做好了,我们直接在空间上用原来的B类卷积就行。

就这样,修改了通道数,修改了卷积核的掩码后,我们成功实现了论文里的PixelCNN。让我们把这个过程总结一下。PixelCNN的核心思想是给图像的子像素定义一个先后顺序,之后让每个子像素的颜色取值分布由之前所有的子像素决定。实现PixelCNN时,可以用任意一种CNN架构,并注意两点:

  1. 网络的输出是一个经softmax的概率分布。
  2. 网络的所有卷积层要替换成带掩码的卷积层,第一个卷积层用A类掩码,后面的用B类掩码。

学完了PixelCNN,我们在闲暇之余来谈一谈PixelCNN和其他生成网络的对比情况。精通数学的人,会把图像生成问题看成学习一个图像的分布。每次生成一张图片,就是在图像分布里随机采样一个张量。学习一个分布,最便捷的方法是定义一个带参数$\theta$的概率模型$P_\theta$,最大化来自数据集的图像$\mathbf{x}$的概率$P_\theta(\mathbf{x})$。

可问题来了:一个又方便采样,又能计算概率的模型不好设计。VAE和Diffusion建模了把一个来自正态分布的向量$\mathbf{z}$变化成$\mathbf{x}$的过程,并使用了统计学里的变分推理,求出了$P_\theta(\mathbf{x})$的一个下界,再设法优化这个下界。GAN干脆放弃了概率模型,直接拿一个神经网络来评价生成的图像好不好。

PixelCNN则正面挑战了建立概率模型这一任务。它把$P_\theta(\mathbf{x})$定义为每个子像素出现概率的乘积,而每个子像素的概率仅由它之前的子像素决定。

由于我们可以轻松地用神经网络建模每个子像素的概率分布并完成采样,PixelCNN的采样也是很方便的。我们可以说PixelCNN是一个既方便采样,又能快速地求出图像概率的模型。

相比与其他生成模型,PixelCNN直接对$P_\theta(\mathbf{x})$建模,在和概率相关的指标上表现优秀。很可惜,能最大化数据集的图像的出现概率,并不代表图像的生成质量就很优秀。因此,一直以来,以PixelCNN为代表的对概率直接建模的生成模型没有受到过多的关注。可能只有少数必须要计算图像概率分布的任务才会用到PixelCNN。

除了能直接计算图像的概率外,PixelCNN还有一大特点:PixelCNN能输出离散的颜色值。VAE和GAN这些模型都是把图像的颜色看成一个连续的浮点数,模型的输入和输出的取值范围都位于-1到1之间(有些模型是0到1之间)。而PixelCNN则输出的是像素取某个颜色的概率分布,它能描述的颜色是有限而确定的。假如我们是在生成8位单通道图像,那网络就只输出256个离散的概率分布。能生成离散输出这一特性启发了后续很多生成模型。另外,这一特性也允许我们指定颜色的亮度级别。比如对于黑白手写数字数据集MNIST,我们完全可以用黑、白两种颜色来描述图像,而不是非得用256个灰度级来描述图像。减少亮度级别后,网络的训练速度能快上很多。

在后续的文献中,PixelCNN被归类为了自回归生成模型。这是因为PixelCNN在生成图像时,要先输入空图像,得到第一个像素;把第一个像素填入空图像,输入进模型,得到第二个像素……。也就是说,一个图像被不断扔进模型,不断把上一时刻的输出做为输入。这种用自己之前时刻的状态预测下一个状态的模型,在统计学里被称为自回归模型。如果你在其他图像生成文献中见到了「自回归模型」这个词,它大概率指的就是PixelCNN这种每次生成一个像素,该像素由之前所有像素决定的生成模型。

Gated PixelCNN

首篇提出PixelCNN的论文叫做Pixel Recurrent Neural Networks。没错!这篇文章的作者提出了一种叫做PixelRNN的架构,PixelCNN只是PixelRNN的一个变种。可能作者一开始也没指望PixelCNN有多强。后来,人们发现PixelCNN的想法还挺有趣的,但是原始的PixelCNN设计得太烂了,于是开始着手改进原始的PixelCNN。

PixelCNN的掩码卷积其实有一个重大漏洞:像素存在视野盲区。如下图所示,在我们刚刚的实验中,中心像素看不到右上角三个本应该能看到的像素。哪怕你对用B类卷积多卷几次,右上角的视野盲区都不会消失。

为此,PixelCNN论文的作者们又打了一些补丁,发表了Conditional Image Generation with PixelCNN Decoders这篇论文。这篇论文提出了一种叫做Gated PixelCNN的改进架构。Gated PixelCNN使用了一种更好的掩码卷积机制,消除了原PixelCNN里的视野盲区。如下图所示,Gated PixelCNN使用了两种卷积——垂直卷积和水平卷积——来分别维护一个像素上侧的信息和左侧的信息。垂直卷积的结果只是一些临时量,而水平卷积的结果最终会被网络输出。可以看出,使用这种新的掩码卷积机制后,每个像素能正确地收到之前所有像素的信息了。

除此之外,Gated PixelCNN还把网络中的激活函数从ReLU换成了LSTM的门结构。Gated PixelCNN用下图的模块代替了原PixelCNN的普通残差模块。
模块的输入输出都是两个量,左边的量是垂直卷积中间结果,右边的量是最后用来计算输出的量。垂直卷积的结果会经过偏移和一个1x1卷积,再加到水平卷积的结果上。两条计算路线在输出前都会经过门激活单元。所谓门激活单元,就是输入两个形状相同的量,一个做tanh,一个做sigmoid,两个结果相乘再输出。此外,模块右侧那部分还有一个残差连接。

除了上面的两项改动,Gated PixelCNN还做出了其他的一些改动。比如,Gated PixelCNN支持带约束的图像生成,比如根据文字生成图片、根据类别生成图片。用于约束生成的向量$\mathbf{h}$会被输入进网络每一层的激活函数中。当然,这些改动不是为了提升原PixelCNN的性能。

PixelCNN++

之后,VAE的作者也下场了,提出了一种改进版的PixelCNN,叫做PixelCNN++。这篇论文没有多余的废话,在摘要中就简明地指出了PixelCNN++的几项改动:

  1. 使用logistic分布代替256路softmax
  2. 简化RGB子像素之间的约束关系
  3. 使用U-Net架构
  4. 使用dropout正则化

这几项改动中,第一项改动是最具启发性的,这一技巧可以拓展到其他任务上。让我们主要学习一下第一项改动,并稍微浏览一下其他几项改动。

离散logistic混合似然

原PixelCNN使用256路softmax建模一个像素的颜色概率分布。这么做确实能让模型更加灵活,但有若干缺点。首先,计算这么多的概率值会占很多内存;其次,由于每次训练只有一个位置的标签为1,其他255个位置的标签都是0,模型可学习参数的梯度会很稀疏;最后,在这种概率分布方式下,256种颜色是分开考虑的,这导致模型并不知道相邻的颜色比较相似(比如颜色值128和127、129比较相似)这一事实。总之,用softmax独立地表示各种颜色有着诸多的不足。

作者把颜色的概率分布建模成了连续分布,一下子克服掉了上述所有难题。让我们来仔细看一下新概率分布的定义方法。

首先,新概率分布使用到的连续分布叫做logistic分布。它有两个参数:均值$\mu$和方差$s^2$。它的概率密度函数为:

logistic分布的概率密度函数看起来比较复杂。但是,如果对这个函数积分,得到的累计分布函数就是logistic函数。如果令均值为0,方差为1,则logistic函数就是我们熟悉的sigmoid函数了。

接着,每个分布可能是$K$个参数不同的logistic分布中的某一个,选择某个logistic分布的概率由$\pi_i$表示。比如$K=2$,$\pi_1 = 0.3, \pi_2=0.7$,就说明有两个可选的logisti分布,每个分布有30%的概率会使用1号logistic分布,有70%的概率会使用2号logistic分布。 这里的$\pi_i$和原来256路softmax的输出的意义一样,都是选择某个东西的概率。当然,$K$会比256要小很多,不然这种改进就起不到减小计算量的作用了。设一个输出颜色为$v$,它的数学表达式为:

可logsitc分布是一个连续分布,而我们想得到256个颜色中某个颜色的概率,即得到一个离散的分布。因此,在最后一步,我们要从上面这个连续分布里得到一个离散的分布。我们先不管$K$和$\pi_i$,只考虑有一个logistic分布的情况。根据统计学知识可知,要从连续分布里得到一个离散分布,可以把定义域拆成若干个区间,对每个区间的概率求积分。在我们的例子里,我们可以把实数集拆成256个区间,令$(-\infty, 0.5]$为第1个区间,$(0.5, 1.5]$为第2个区间,……,$(253.5, 254.5]$为第255个区间, $(254.5, +\infty)$为第256个区间。

对概率密度函数求积分,就是在累积分布函数上做差。因此,对于某个离散颜色值$x\in[0, 255], x\in \mathbb{N}$,已知一个logistic分布$logistic(\mu, s)$,则这个颜色值的出现概率是:

其中,$\sigma()$是sigmoid函数。$\sigma((x-\mu)/s)$就是分布的累积分布函数。

可以看出,使用这种区间划分方法,位于0处和位于255处的颜色的概率相对会高一点。这一特点符合作者统计出的CIFAR-10里的颜色分布规律。

当有$K$个logistic分布时,只要把各个分布的概率做一个加权和就行(公式省略掉了$x$位于边界处的情况)。

至此,我们已经知道了怎么用一个「离散logistic混合似然」来建模颜色的概率分布了。这个更高级的颜色分布以logistic分布为基础,以比例(概率)$\pi_i$混合了$K$个logstic分布,并用巧妙的方法把连续分布转换成了离散分布。

简化RGB子像素之间的约束关系

在原PixelCNN中,生成一个像素的RGB三个子像素时,为了保证子像素之间的约束,我们要把模型中所有特征图的通道分成三组,并用掩码来维持三组通道间的约束。这样做太麻烦了。因此,PixelCNN++对约束做了一定的简化:根据之前所有像素,网络一次性输出三个子像素的均值和方差,而不用掩码区分三个子像素的信息。当然,只是这样做是不够好的——G子像素缺少了R子像素的信息,B子像素缺少了R、G子像素的信息。为了弥补信息的缺失,PixelCNN会为每个像素额外输出三个参数$\alpha, \beta, \gamma$,$\alpha$描述R对G子像素的约束关系,$\beta$描述R对B的约束关系,$\gamma$描述G对B的约束关系。

让我们来用公式更清晰地描述这一过程。对于某个像素的第$i$个logistic分布,网络会输出10个参数:$\pi, \mu_r, \mu_g, \mu_b, s_r, s_g, s_b, \alpha, \beta, \gamma$。$\pi$就是之前见过的选择第$i$个分布的概率,$\mu_r, \mu_g, \mu_b$是网络输出的三个子像素的均值,$s_r, s_g, s_b$是网络输出的三个子像素的标准差,$\alpha, \beta, \gamma$描述子像素之间的约束。

由于缺少了其他子像素的信息,网络直接输出的$\mu_g, \mu_b$是不准的。我们假设子像素之间仅存在简单的线性关系。这样,可以用下面的公式更新$\mu_g$和$\mu_b$:

更新后的$\mu_g$和$\mu_b$才是训练和采样时使用的最终均值。

你会不会疑惑上面那个公式里的$r$和$g$是哪里来的?别忘了,虽然子像素之间的约束被简化了,但是三个子像素还是按先后顺序生成的。在训练时,我们是知道所有子像素的真值的,公式里的$r$和$g$来自真值;而在采样时,我们会先用神经网络生成三个子像素的均值和方差,再采样$r$,把采样的$r$套入公式采样出$g$,最后把采样的$r,g$套入公式采样出$b$.

使用U-Net架构

PixelCNN++的网络架构是一个三级U-Net,即网络先下采样两次再上采样两次,同级编码器(下采样部分)的输出会连到解码器(上采样部分)的输入上。这个U-Net和其他任务中的U-Net没什么太大的区别。

使用Dropout

过拟合会导致生成图像的观感不好。为此,PixelCNN++采用了Dropout正则化方法,在每个子模块的第一个卷积后面加了一个Dropout。

除了这些改动外,PixelCNN++还使用了类似于Gated PixelCNN里垂直卷积和水平卷积的设计,以消除原PixelCNN里的视野盲区。当然,这点不算做本文的主要贡献。

总结

PixelCNN把文本生成的想法套入了图像生成中,令子像素的生成有一个先后顺序。为了在维护先后顺序的同时执行并行训练,PixelCNN使用了掩码卷积。这种并行训练与掩码的设计和Transformer一模一样。如果你理解了Transformer,就能一眼看懂PixelCNN的原理。

相比与其他的图像生成模型,以PixelCNN为代表的自回归模型在生成效果上并不优秀。但是,PixelCNN有两个特点:能准确计算某图像在模型里的出现概率(准确来说在统计学里叫做「似然」)、能生成离散的颜色输出。这些特性为后续诸多工作铺平了道路。

原版的PixelCNN有很多缺陷,后续很多工作对其进行了改进。Gated PixelCNN主要消除了原PixelCNN里的视野盲区,而PixelCNN++提出了一种泛用性很强的用连续分布建模离散颜色值的方法,并用简单的线性约束代替了原先较为复杂的用神经网络表示的子像素之间的约束。

PixelCNN相关的知识难度不高,了解本文介绍的内容足矣。PixelCNN也不是很常见的架构,复现代码的优先级不高,有时间的话阅读一下本文附录中的代码即可。另外,PixelCNN的代码实现里有一个重要的知识点。这个知识点几乎不会在论文和网上的文章里看到,但它对实现是否成功有着重要的影响。如果你对新知识感兴趣,推荐去读一下附录中对其的介绍。

参考资料与学习提示

Pixel Recurrent Neural Networks 是提出了PixelCNN的文章。当然,这篇文章主要是在讲PixelRNN,只想学PixelCNN的话通读这篇文章的价值不大。

Conditional Image Generation with PixelCNN Decoders 是提出Gated PixelCNN的文章。可以主要阅读消除视野盲区和门激活函数的部分。

PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications 是提出PixelCNN++的文章。整篇文章非常简练,可以整体阅读一遍,并且着重阅读离散logistic混合似然的部分。不过,这篇文章有很多地方写得过于简单了,连公式里的字母都不好好交代清楚,我还是看代码才看懂他们想讲什么。建议搭配本文的讲解阅读。

这几篇文章都使用了NLL(负对数似然)这个评价指标。实际上,这个指标就是对所有数据在模型里的平均出现概率取了个对数,加了个负号。对于PixelCNN,其NLL就是交叉熵损失函数。其他生成模型不是直接对数据的概率分布建模,它们的NLL不好求得。比如diffusion模型只能计算NLL的一个上界。

网上还有几份PyTorch代码复现供参考:

PixelCNN:https://github.com/singh-hrituraj/PixelCNN-Pytorch

Gated PixelCNN:https://github.com/anordertoreclaim/PixelCNN

附录:代码学习

在附录中,我将给出PixelCNN和Gated PixelCNN的PyTorch实现,并讲解PixelCNN++开源代码的实现细节。

PixelCNN 与 GatedPixelCNN

为了简化实现,我们来实现MNIST上的PixelCNN和Gated PixelCNN。MNIST是单通道数据集,我们不用考虑颜色通道之间复杂的约束。代码仓库:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/pixelcnn。

我们先准备好数据集。PyTorch的torchvision提供了获取了MNIST的接口,我们只需要用下面的函数就可以生成MNIST的Dataset实例。参数中,root为数据集的下载路径,download为是否自动下载数据集。令download=True的话,第一次调用该函数时会自动下载数据集,而第二次之后就不用下载了,函数会读取存储在root里的数据。

1
mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True)

我们可以用下面的代码来下载MNIST并输出该数据集的一些信息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torchvision
from torchvision.transforms import ToTensor
def download_dataset():
mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True)
print('length of MNIST', len(mnist))
id = 4
img, label = mnist[id]
print(img)
print(label)

# On computer with monitor
# img.show()

img.save('work_dirs/tmp.jpg')
tensor = ToTensor()(img)
print(tensor.shape)
print(tensor.max())
print(tensor.min())

if __name__ == '__main__':
import os
os.makedirs('work_dirs', exist_ok=True)
download_dataset()

执行这段代码,输出大致为:

1
2
3
4
5
6
length of MNIST 60000
<PIL.Image.Image image mode=L size=28x28 at 0x7FB3F09CCE50>
9
torch.Size([1, 28, 28])
tensor(1.)
tensor(0.)

第一行输出表明,MNIST数据集里有60000张图片。而从第二行和第三行输出中,我们发现每一项数据由图片和标签组成,图片是大小为28x28的PIL格式的图片,标签表明该图片是哪个数字。我们可以用torchvision里的ToTensor()把PIL图片转成PyTorch张量,进一步查看图片的信息。最后三行输出表明,每一张图片都是单通道图片(灰度图),颜色值的取值范围是0~1。

我们可以查看一下每张图片的样子。如果你是在用带显示器的电脑,可以去掉img.show那一行的注释,直接查看图片;如果你是在用服务器,可以去img.save的路径里查看图片。该图片的应该长这个样子:

我们可以用下面的代码预处理数据并创建DataLoader。PixelCNN对输入图片的颜色取值没有特别的要求,我们可以不对图片的颜色取值做处理,保持取值范围在0~1即可。

1
2
3
4
5
6
from torch.utils.data import DataLoader

def get_dataloader(batch_size: int):
dataset = torchvision.datasets.MNIST(root='./data/mnist',
transform=ToTensor())
return DataLoader(dataset, batch_size=batch_size, shuffle=True)

准备好数据后,我们来实现PixelCNN和Gated PixelCNN。先从PixelCNN开始。

实现PixelCNN,最重要的是实现掩码卷积。其代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskConv2d(nn.Module):

def __init__(self, conv_type, *args, **kwags):
super().__init__()
assert conv_type in ('A', 'B')
self.conv = nn.Conv2d(*args, **kwags)
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[0:H // 2] = 1
mask[H // 2, 0:W // 2] = 1
if conv_type == 'B':
mask[H // 2, W // 2] = 1
mask = mask.reshape((1, 1, H, W))
self.register_buffer('mask', mask, False)

def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res

掩码卷积的实现思路就是在卷积核组上设置一个mask。在前向传播的时候,先让卷积核组乘mask,再做普通的卷积。因此,掩码卷积类里需要实现一个普通卷积的操作。实现普通卷积,既可以写成继承nn.Conv2d,也可以把nn.Conv2d的实例当成成员变量。这份代码使用了后一种实现方法。在__init__里把其他参数原封不动地传给self.conv,并在forward中直接调用self.conv(x)

1
2
3
4
5
6
7
8
9
10
11
12
class MaskConv2d(nn.Module):

def __init__(self, conv_type, *args, **kwags):
super().__init__()
...
self.conv = nn.Conv2d(*args, **kwags)
...

def forward(self, x):
...
conv_res = self.conv(x)
return conv_res

准备好卷积对象后,我们来维护掩码张量。由于输入输出都是单通道图像,按照正文中关于PixelCNN的描述,我们只需要在卷积核的h, w两个维度设置掩码。我们可以用下面的代码生成一个形状为(H, W)的掩码并根据卷积类型对掩码赋值:

1
2
3
4
5
6
7
8
9
10
def __init__(self, conv_type, *args, **kwags):
super().__init__()
assert conv_type in ('A', 'B')
...
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[0:H // 2] = 1
mask[H // 2, 0:W // 2] = 1
if conv_type == 'B':
mask[H // 2, W // 2] = 1

然后,为了保证掩码能正确广播到4维的卷积核组上,我们做一个reshape操作。

1
mask = mask.reshape((1, 1, H, W))

在初始化函数的最后,我们把用PyTorch API把mask注册成名为'mask'的成员变量。register_buffer可以把一个变量加入成员变量的同时,记录到PyTorch的Module中。这样做的好处时,每当执行model.to(device)把模型中所有参数转到某个设备上时,被注册的变量会跟着转。否则的话我们要手动model.mask = model.mask.to(device)转设备。register_buffer的第三个参数表示被注册的变量是否要加入state_dict中以保存下来。由于这里mask每次都会自动生成,我们不需要把它存下来,可以令第三个参数为False

1
self.register_buffer('mask', mask, False)

在前向传播时,只需要先让卷积核乘掩码,再做普通的卷积。

1
2
3
4
def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res

有了最核心的掩码卷积,我们来根据论文中的模型结构图把模型搭起来。

我们先照着论文实现残差块ResidualBlock。原论文并没有使用归一化,但我发现使用归一化后效果会好一点,于是往模块里加了BatchNorm。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class ResidualBlock(nn.Module):

def __init__(self, h, bn=True):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(2 * h, h, 1)
self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()
self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)
self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()
self.conv3 = nn.Conv2d(h, 2 * h, 1)
self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()

def forward(self, x):
y = self.relu(x)
y = self.conv1(y)
y = self.bn1(y)
y = self.relu(y)
y = self.conv2(y)
y = self.bn2(y)
y = self.relu(y)
y = self.conv3(y)
y = self.bn3(y)
y = y + x
return y

有了所有这些基础模块后,我们就可以拼出最终的PixelCNN了。注意,我们可以自己决定颜色有几个亮度级别。要修改亮度级别的数量,只需要修改softmax输出的通道数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class PixelCNN(nn.Module):

def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):
super().__init__()
self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)
self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()
self.residual_blocks = nn.ModuleList()
for _ in range(n_blocks):
self.residual_blocks.append(ResidualBlock(h, bn))
self.relu = nn.ReLU()
self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)
self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
self.out = nn.Conv2d(linear_dim, color_level, 1)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
for block in self.residual_blocks:
x = block(x)
x = self.relu(x)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
x = self.out(x)
return x

PixelCNN实现完毕,我们来按照同样的流程实现Gated PixelCNN。首先,我们要实现其中的垂直掩码卷积和水平掩码卷积,二者的实现和PixelCNN里的掩码卷积差不多,只是mask的内容不太一样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class VerticalMaskConv2d(nn.Module):

def __init__(self, *args, **kwags):
super().__init__()
self.conv = nn.Conv2d(*args, **kwags)
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[0:H // 2 + 1] = 1
mask = mask.reshape((1, 1, H, W))
self.register_buffer('mask', mask, False)

def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res


class HorizontalMaskConv2d(nn.Module):

def __init__(self, conv_type, *args, **kwags):
super().__init__()
assert conv_type in ('A', 'B')
self.conv = nn.Conv2d(*args, **kwags)
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[H // 2, 0:W // 2] = 1
if conv_type == 'B':
mask[H // 2, W // 2] = 1
mask = mask.reshape((1, 1, H, W))
self.register_buffer('mask', mask, False)

def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res

水平卷积其实只要用一个1x3的卷积就可以实现了。但出于偷懒(也为了方便理解),我还是在3x3卷积的基础上添加的mask

之后我们来用两种卷积搭建论文中的Gated Block。

Gated Block搭起来稍有难度。如上面的结构图所示,我们主要要维护两个v, h两个变量,分别表示垂直卷积部分的结果和水平卷积部分的结果。v会经过一个垂直掩码卷积和一个门激活函数。h会经过一个类似于残差块的结构,只不过第一个卷积是水平掩码卷积、激活函数是门激活函数、进入激活函数之前会和垂直卷积的信息融合。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class GatedBlock(nn.Module):

def __init__(self, conv_type, in_channels, p, bn=True):
super().__init__()
self.conv_type = conv_type
self.p = p
self.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, 1)
self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,
1)
self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
self.h_output_conv = nn.Conv2d(p, p, 1)
self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()

def forward(self, v_input, h_input):
v = self.v_conv(v_input)
v = self.bn1(v)
v_to_h = v[:, :, 0:-1]
v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
v_to_h = self.v_to_h_conv(v_to_h)
v_to_h = self.bn2(v_to_h)

v1, v2 = v[:, :self.p], v[:, self.p:]
v1 = torch.tanh(v1)
v2 = torch.sigmoid(v2)
v = v1 * v2

h = self.h_conv(h_input)
h = self.bn3(h)
h = h + v_to_h
h1, h2 = h[:, :self.p], h[:, self.p:]
h1 = torch.tanh(h1)
h2 = torch.sigmoid(h2)
h = h1 * h2
h = self.h_output_conv(h)
h = self.bn4(h)
if self.conv_type == 'B':
h = h + h_input
return v, h

代码中的其他地方都比较常规,唯一要注意的是vh的合成部分。这一部分的实现初看下来比较难懂。为了把v的信息贴到h上,我们并不是像前面的示意图所写的令v上移一个单位,而是用下面的代码令v下移了一个单位(下移即去掉最下面一行,往最上面一行填0)。

1
2
v_to_h = v[:, :, 0:-1]
v_to_h = F.pad(v_to_h, (0, 0, 1, 0))

为什么实际上是要对特征图v下移一个单位呢?实际上,在拼接vh时,我们是想做下面这个计算:

1
2
3
for i in range(H):
for j in range(W):
h[:, :, i, j] += v[:, :, i - 1, j]

但是,写成循环就太慢了,我们最好是能做向量化计算。注意到,vi相加的位置只差了一个单位。为了把相加的位置对齐,我们要把v往下移一个单位,把原来在i-1处的信息移到i上。这样,移动过后的v_to_h就能和h直接用向量加法并行地加到一起了。

除了vh的合成有点麻烦外,GatedBlock还有一个细节值得注意。h的计算通路中有一个残差连接,但是,在网络的第一层,每个数据是不能看到自己的。所以,当GatedBlock发现卷积类型为A类时,不应该对h做残差连接。

最后,我们来用GatedBlock搭出Gated PixelCNN。Gated PixelCNN和原版PixelCNN的结构非常相似,只是把ResidualBlock替换成了GatedBlock而已。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class GatedPixelCNN(nn.Module):

def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
super().__init__()
self.block1 = GatedBlock('A', 1, p, bn)
self.blocks = nn.ModuleList()
for _ in range(n_blocks):
self.blocks.append(GatedBlock('B', p, p, bn))
self.relu = nn.ReLU()
self.linear1 = nn.Conv2d(p, linear_dim, 1)
self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
self.out = nn.Conv2d(linear_dim, color_level, 1)

def forward(self, x):
v, h = self.block1(x, x)
for block in self.blocks:
v, h = block(v, h)
x = self.relu(h)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
x = self.out(x)
return x

准备好了模型代码,我们可以编写训练和采样的脚本了。我们先用超参数初始化好两个模型。根据论文的描述,PixelCNN有15个残差块,中间特征的通道数为128,输出前线性层的通道数为32。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from dldemos.pixelcnn.dataset import get_dataloader, get_img_shape
from dldemos.pixelcnn.model import PixelCNN, GatedPixelCNN

import torch
import torch.nn as nn
import torch.nn.functional as F

import time
import einops
import cv2

import numpy as np
import os

batch_size = 128
color_level = 8 # or 256

models = [
PixelCNN(15, 128, 32, True, color_level),
GatedPixelCNN(15, 128, 32, True, color_level)
]

if __name__ == '__main__':
os.makedirs('work_dirs', exist_ok=True)
model_id = 1
model = models[model_id]
device = 'cuda'
model_path = f'dldemos/pixelcnn/model_{model_id}_{color_level}.pth'
train(model, device, model_path)
sample(model, device, model_path,
f'work_dirs/pixelcnn_{model_id}_{color_level}.jpg')

之后是训练部分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def train(model, device, model_path):
dataloader = get_dataloader(batch_size)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
loss_fn = nn.CrossEntropyLoss()
n_epochs = 40
tic = time.time()
for e in range(n_epochs):
total_loss = 0
for x, _ in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)
y = torch.ceil(x * (color_level - 1)).long()
y = y.squeeze(1)
predict_y = model(x)
loss = loss_fn(predict_y, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * current_batch_size
total_loss /= len(dataloader.dataset)
toc = time.time()
torch.save(model.state_dict(), model_path)
print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
print('Done')

这部分代码十分常规,和普通的多分类任务十分类似。代码中值得一看的是下面几行:

1
2
3
4
y = torch.ceil(x * (color_level - 1)).long()
y = y.squeeze(1)
predict_y = model(x)
loss = loss_fn(predict_y, y)

这几行代码根据输入x得到了标签y,再做前向传播,最后用预测的predict_yy求交叉熵损失函数。这里第一个要注意的地方是y = y.squeeze(1)这一行。在PyTorch中用交叉熵函数时,标签的形状应该为[N, A, B, ...],预测值的形状应为[N, num_class, A, B, ...]。其中,A,B, ...表示数据的形状。在我们的任务中,数据是二维的,因此标签的形状应为[N, H, W],预测值的形状应为[N, num_class, H, W]。而我们在DataLoader中获得的数据的形状是[N, 1, H, W]。我们要对数据y的形状做一个变换,使之满足PyTorch的要求。这里由于输入是单通道,我们可以随便用squeeze()y长度为1的通道去掉。如果图像是多通道的话,我们则不应该修改y,而是要对预测张量y_predict做一个reshape,改成[N, num_class, C, H, W]

第二个要注意的是y = torch.ceil(x * (color_level - 1)).long()这一行。为什么需要写一个这么复杂的浮点数转整数呢?这个地方的实现需要多解释几句。在我们的代码中,PixelCNN的输入可能来自两个地方:

  1. 训练时,PixelCNN的输入来自数据集。数据集里的颜色值是0~1的浮点数。
  2. 采样时,PixelCNN的输入来自PixelCNN的输出。PixelCNN的输出是整型(别忘了,PixelCNN只能产生离散的输出)。

两种输入,一个是0~1的浮点数,一个是0~color_level-1的整数。为了统一两个输入的形式,最简单的做法是对整型颜色输入做个除法,映射到0~1里,把它统一到浮点数上。

此外,还有一个地方需要类型转换。在训练时,我们需要得到每个像素的标签,即得到每个像素颜色的真值。由于PixelCNN的输出是离散的,这个标签也得是一个离散的颜色。而标签来自训练数据,训练数据又是0~1的浮点数。因此,在计算标签时,需要做一次浮点到整型的转换。这样,整个项目里就有两个重要的类型转换:一个是在获取标签时把浮点转整型,一个是在采样时把整型转浮点。这两个类型转换应该恰好「互逆」,不然就会出现转过去转不回来的问题。

在项目中,我使用了下图所示的浮点数映射到整数的方法。0.0映射到0,(0, 1/255]映射到1,……(254/255, 1]映射到255。即浮点转整型时使用ceil(x*255),整型转浮点的时候使用x/255。这种简单的转换方法保证一个区间里的离散颜色值只会映射到一个整数上,同时把整数映射回浮点数时该浮点数也会落在区间里。如果你随手把浮点转整型写成了int(x*255),则会出现浮点转整数和整数转浮点对应不上的问题,到时候采样的结果会很不好。

由于一个整型只能映射到一个浮点数,而多个浮点数会映射到一个整数,严格来说,大部分浮点数转成整数再转回来是变不回原来的浮点数的。这两个转换过程从数学上来说不是严格的互逆。但是,如果我们马虎一点,把位于同一个区间的浮点数看成等价的,那么浮点数和整数之间的映射就是一个双射,来回转换不会有任何信息损失。

刚才代码中y = torch.ceil(x * (color_level - 1)).long()这一行实际上就是在描述怎样把训练集的浮点颜色值转换成0~color_level-1之间的整型标签的。

再来看看采样部分的代码。和正文里的描述一样,在采样时,我们把x初始化成一个0张量。之后,循环遍历每一个像素,输入x,把预测出的下一个像素填入x.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def sample(model, device, model_path, output_path, n_sample=81):

model.eval()
model.load_state_dict(torch.load(model_path))
model = model.to(device)
C, H, W = get_img_shape() # (1, 28, 28)
x = torch.zeros((n_sample, C, H, W)).to(device)
with torch.no_grad():
for i in range(H):
for j in range(W):
output = model(x)
prob_dist = F.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist,
1).float() / (color_level - 1)
x[:, :, i, j] = pixel

imgs = x * 255
imgs = imgs.clamp(0, 255)
imgs = einops.rearrange(imgs,
'(b1 b2) c h w -> (b1 h) (b2 w) c',
b1=int(n_sample**0.5))

imgs = imgs.detach().cpu().numpy().astype(np.uint8)

cv2.imwrite(output_path, imgs)

整个采样代码的核心部分是下面这几行。我们先获取模型的输出,再用softmax转换成概率分布,再用torch.multinomial(prob_dist, 1)从概率分布里采样出一个0~(color_level-1)的离散颜色值,再除以(color_level - 1)把离散颜色转换成浮点颜色(因为网络是输入是浮点颜色),最后把新像素填入生成图像。

1
2
3
4
5
output = model(x)
prob_dist = F.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist,
1).float() / (color_level - 1)
x[:, :, i, j] = pixel

上面的代码中,如前所述,/ (color_level - 1)与前面的torch.ceil(x * (color_level - 1)).long()必须是对应起来的。两个操作必须「互逆」,不然就会出问题。

当然,最后得到的图像x是一个用0~1浮点数表示的图像,可以直接把它乘255变成一个用8位字节表示的图像,这一步浮点到整型的转换是为了让图像输出,和其他图像任务的后处理是一样的,和PixelCNN对于离散颜色和连续颜色的建模不是同一个意思,不是非得取一次ceil()

PixelCNN训练起来很慢。在代码中,我默认训练40个epoch。原版PixelCNN要花一小时左右训完,Gated PixelCNN就更慢了。

以下是我得到的一些采样结果。首先是只有8个颜色级别的PixelCNN和Gated PixelCNN。

可以看出,PixelCNN经常会生成一些没有意义的「数字」,而Gated PixelCNN生成的大部分数字都是正常的。但由于颜色级别只有8,模型偶尔会生成较粗的色块。这个在Gated PixelCNN的输出里比较明显。

之后看一下正常的256个颜色级别的PixelCNN和Gated PixelCNN采样结果。

由于颜色级别增大,任务难度变大,这两个模型的生成效果就不是那么好了。当然,Gated PixelCNN还是略好一些。训练效果差,与MNIST的特性(大部分像素都是0和255)以及PixelCNN对于离散颜色的建模有关。PixelCNN的这一缺陷已经在PixelCNN++论文里分析过了。

PixelCNN++ 源码阅读

PixelCNN++在实现上细节颇多,复现起来难度较大。而且它的官方实现是拿TensorFlow写的,对于只会PyTorch的选手来说不够友好。还好,PixelCNN++的官方实现非常简练,核心代码只有两个文件,没有过度封装,也没有过度使用API,哪怕不懂TensorFlow也不会有障碍(但由于代码中有很多科学计算,阅读起来没有障碍,却难度不小)。让我们来通过阅读官方源码来学习PixelCNN++的实现。

官方代码的地址在 https://github.com/openai/pixel-cnn 。源码有两个核心文件:nn.py实现了网络模块及一些重要的训练和采样函数,model.py定义了网络的结构。让我们自顶向下地学习,先看model.py,看到函数调用后再跑到nn.py里查看实现细节。

model.py里就只有一个函数model_spec,它定义了神经网络的结构。
它的参数为:

1
2
3
4
5
6
7
8
9
10
def model_spec(x, 
h=None,
init=False,
ema=None,
dropout_p=0.5,
nr_resnet=5,
nr_filters=160,
nr_logistic_mix=10,
resnet_nonlinearity='concat_elu',
energy_distance=False):

各参数的意义为:

  • x: 形状为[N, H, W, D1]的输入张量。其中,D1表示输入通道数。对于RGB图像,D1=3
  • h: 形状为[N, K]的约束条件,即对于每个batch来说,约束条件是一个长度K的向量。这里的约束条件和Gatd PixelCNN中提出的一样,可以是文字,也可以是类别,只要约束条件最终被转换成一个向量就行。
  • init: 是否执行初始化。这和TensorFlow的实现有关,可以不管。
  • ema: 对参数使用指数移动平均,一种训练优化技巧,和论文无关,可以不管。
  • dropout_p: dropout的概率。
  • nr_resnet: U-Net每一块里有几个ResNet层(U-Net一共有6块,编码器3块解码器3块)。
  • nr_filters: 每个卷积层的卷积核个数,即所有中间特征图的通道数。
  • nr_logistic_mix: 论文里的$K$,表示用几个logistic分布混合起来描述一个颜色分布。
  • resnet_nonlinearity: 激活函数的类别。
  • energy_distance:是否使用论文里没提过的一种算损失函数的办法,可以不管。

之后来看函数体。20行with arg_scope ([nn.conv2d, ...], counters=counters, ...)大概是说进入了TensorFlow里的arg_scope这个上下文。只要在上下文里,后面counters等参数就会被自动传入nn.conv2d等函数,而不需要在函数里显式传参。这样写会让后面的函数调用更简短一点。

22行至30行在选择激活函数,可以直接跳过。

1
2
3
4
5
6
7
8
9
# parse resnet nonlinearity argument
if resnet_nonlinearity == 'concat_elu':
resnet_nonlinearity = nn.concat_elu
elif resnet_nonlinearity == 'elu':
resnet_nonlinearity = tf.nn.elu
elif resnet_nonlinearity == 'relu':
resnet_nonlinearity = tf.nn.relu
else:
raise('resnet nonlinearity ' + resnet_nonlinearity + ' is not supported')

从35行开始,函数正式开始定义网络结构。一开始,代码里有一个匪夷所思的操作:先是取出输入张量的形状xs,再根据这个形状给x填充了一个全是1的通道。

1
2
xs = nn.int_shape(x)
x_pad = tf.concat([x,tf.ones(xs[:-1]+[1])],3) # add channel of ones to distinguish image from padding later on

虽然作者加了注释,说这个x_pad后面会用到。但我翻遍了代码,楞是没找到这个多出来的通道发挥了什么作用。GitHub issue里也有人提问,问这个x_pad在做什么。有其他用户给了回复,说他尝试了去掉填充,结果不变。可见这一行代码确实是毫无贡献,还增加了不必要的计算量。大概是作者没删干净过时的实现代码。

之后的几行是在初始化上卷积和左上卷积的中间结果(上卷积和Gated PixelCNN里的垂直卷积类似,左上卷积和Gated PixelCNN里的水平卷积类似)。u_list会保存所有上卷积在编码器里的结果,ul_list会保存所有左上卷积在编码器里的结果。这些结果会供解码器使用。

1
2
3
4
5
6
7
8
9
10
11
12
 u_list = [nn.down_shift(
nn.down_shifted_conv2d(x_pad,
num_filters=nr_filters,
filter_size=[2, 3])
)] # stream for pixels above
ul_list = [nn.down_shift(
nn.down_shifted_conv2d(x_pad,
num_filters=nr_filters,
filter_size=[1,3])
) + nn.right_shift(
nn.down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1])
)] # stream for up and to the left

作者没有使用带掩码的卷积,而是通过普通卷积加偏移等效实现了掩码卷积。这一实现非常巧妙,效率更高。我们来看看这几个卷积的实现方法。

首先看上卷积down_shifted_conv2d,它表示实现一个卷积中心在卷积核正下方的卷积。作者使用了[2,3]的卷积核,并手动给卷积填充(注意,卷积的类型是'valid'不是'same')。这种卷积等价于我们做普通的3x3卷积再给上面6个像素打上掩码。

1
2
3
def down_shifted_conv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs):
x = tf.pad(x, [[0,0],[filter_size[0]-1,0], [int((filter_size[1]-1)/2),int((filter_size[1]-1)/2)],[0,0]])
return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs)

作者在down_shifted_conv2d之后跟了一个down_shift。这个操作和我们实现Gated PixelCNN时移动v_to_h张量的做法一样,去掉张量最下面一行,在最上面一行填0,也就是让张量往下移了一格。

1
2
3
def down_shift(x):
xs = int_shape(x)
return tf.concat([tf.zeros([xs[0],1,xs[2],xs[3]]), x[:,:xs[1]-1,:,:]],1)

类似地,在做第一次左上卷积时,作者把一个下移过的1x3卷积结果和一个右移过的2x1卷积结果拼到了一起。其中,down_right_shifted_conv2d就是实现一个卷积中心在卷积核右下角的卷积。

1
2
3
4
5
6
7
ul_list = [nn.down_shift(
nn.down_shifted_conv2d(x_pad,
num_filters=nr_filters,
filter_size=[1,3])
) + nn.right_shift(
nn.down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1])
)]

初始化完毕后,数据就正式进入了U-Net。让我们先略过函数的细节,看一看模型的整体架构。在下采样部分,三级U-Net在每一级都是先经过若干个gated_resnet模块,再下采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
for rep in range(nr_resnet):
u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d))

u_list.append(nn.down_shifted_conv2d(u_list[-1], num_filters=nr_filters, stride=[2, 2]))
ul_list.append(nn.down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, stride=[2, 2]))

for rep in range(nr_resnet):
u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d))

u_list.append(nn.down_shifted_conv2d(u_list[-1], num_filters=nr_filters, stride=[2, 2]))
ul_list.append(nn.down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, stride=[2, 2]))

for rep in range(nr_resnet):
u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d))

之后是上采样。类似地,数据先经过若干个gated_resnet模块,再上采样。与前半部分不同的是,前半部分的输出会从u_listul_list中逐个取出(实际上这两个list起到了一个栈的作用),接入到gated_resnet的输入里。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
u = u_list.pop()
ul = ul_list.pop()
for rep in range(nr_resnet):
u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d)
tf.add_to_collection('checkpoints', u)
tf.add_to_collection('checkpoints', ul)

u = nn.down_shifted_deconv2d(u, num_filters=nr_filters, stride=[2, 2])
ul = nn.down_right_shifted_deconv2d(ul, num_filters=nr_filters, stride=[2, 2])

for rep in range(nr_resnet+1):
u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d)
tf.add_to_collection('checkpoints', u)
tf.add_to_collection('checkpoints', ul)

u = nn.down_shifted_deconv2d(u, num_filters=nr_filters, stride=[2, 2])
ul = nn.down_right_shifted_deconv2d(ul, num_filters=nr_filters, stride=[2, 2])

for rep in range(nr_resnet+1):
u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d)
tf.add_to_collection('checkpoints', u)
tf.add_to_collection('checkpoints', ul)

模型U-Net的部分到此为止。整个网络的结构并不复杂,我们只要看懂了nn.gated_resnet的实现,就算理解了整个模型的实现。让我们来详细看一下这个模块是怎么实现的。以下是整个模块的实现代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs):
xs = int_shape(x)
num_filters = xs[-1]

c1 = conv(nonlinearity(x), num_filters)
if a is not None: # add short-cut connection if auxiliary input 'a' is given
c1 += nin(nonlinearity(a), num_filters)
c1 = nonlinearity(c1)
if dropout_p > 0:
c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)
c2 = conv(c1, num_filters * 2, init_scale=0.1)

# add projection of h vector if included: conditional generation
if h is not None:
with tf.variable_scope(get_name('conditional_weights', counters)):
hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,
initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
if init:
hw = hw.initialized_value()
c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])

a, b = tf.split(c2, 2, 3)
c3 = a * tf.nn.sigmoid(b)
return x + c3

照例,我们来先看一下函数的每个参数的意义。

1
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs)
  • x: 模块的输入。
  • a: 模块的附加输入。附加输入有两个来源:上方u_list的信息传递给左上方ul_list的信息、编码器把信息传递给解码器。
  • h: 形状为[N, K]的约束条件。从模型的参数里传递而来。
  • nonlinearity: 激活函数。从模型的参数里传递而来。
  • conv:卷积操作的函数。可能是上卷积或者左上卷积。
  • init: 是否执行初始化。这和TensorFlow的实现有关,可以不管。
  • counters: 作者写的一个用于方便地给模块的命名的字典,可以不管。
  • ema: 对参数使用指数移动平均。从模型的参数里传递而来。
  • dropout_p: dropout的概率。从模型的参数里传递而来。

模块主要是做了下面这些卷积操作。一开始,先对输入x做卷积,得到c1。如果有额外输入a,则对a做一个1x1卷积(作者自己实现了1x1卷积,把函数命名为nin),加到c1上。做完第一个卷积后,过一个dropout层。最后再卷积一次,得到2*num_filters通道数的张量。

1
2
3
4
5
6
7
c1 = conv(nonlinearity(x), num_filters)
if a is not None: # add short-cut connection if auxiliary input 'a' is given
c1 += nin(nonlinearity(a), num_filters)
c1 = nonlinearity(c1)
if dropout_p > 0:
c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)
c2 = conv(c1, num_filters * 2, init_scale=0.1)

之后,作者也使用了一种门结构作为整个模块的激活函数。但是和Gated PixelCNN相比,PixelCNN++的门结构简单一点。详见下面的代码。

1
2
a, b = tf.split(c2, 2, 3)
c3 = a * tf.nn.sigmoid(b)

最后输出时,c3和输入x之间有一个残差连接。

1
return x + c3

看完gated_resnet的实现,我们可以跳回去继续看模型结构了。经过了U-Net的主体结构后,只需要经过一个输出层就可以得到最终的输出了。输出层里,作者用1x1卷积修改了输出通道数,令最后的通道数为10*nr_logistic_mix

1
2
3
4
5
6
7
8
9
if energy_distance:
# 跳过
else:
x_out = nn.nin(tf.nn.elu(ul),10*nr_logistic_mix)

assert len(u_list) == 0
assert len(ul_list) == 0

return x_out

大家还记得这个10是从哪里来的吗?在正文中,我们曾经学过,对于某个像素的第$i$个logistic分布,网络会输出10个参数:$\pi, \mu_r, \mu_g, \mu_b, s_r, s_g, s_b, \alpha, \beta, \gamma$。这个10就是10个参数的意思。

光知道一共有10个参数还不够。接下来就是PixelCNN++比较难懂的部分——怎么用这些参数构成一共logistic分布,并从连续分布中得到离散的概率分布。这些逻辑被作者写在了损失函数nn.discretized_mix_logistic_loss里面。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def discretized_mix_logistic_loss(x,l,sum_all=True):
""" log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """
xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100)
nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics
logit_probs = l[:,:,:,:nr_mix]
l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3])
means = l[:,:,:,:,:nr_mix]
log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.)
coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])
x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels
m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix])
m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix])
means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3],3)
centered_x = x - means
inv_stdv = tf.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1./255.)
cdf_plus = tf.nn.sigmoid(plus_in)
min_in = inv_stdv * (centered_x - 1./255.)
cdf_min = tf.nn.sigmoid(min_in)
log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling)
log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling)
cdf_delta = cdf_plus - cdf_min # probability for all other cases
mid_in = inv_stdv * centered_x
log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code)

log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.where(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5))))

log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs)
if sum_all:
return -tf.reduce_sum(log_sum_exp(log_probs))
else:
return -tf.reduce_sum(log_sum_exp(log_probs),[1,2])

这个函数很长,很难读。它实际上可以被拆成四个部分:取参数、求均值、求离散概率、求和。让我们一部分一部分看过来。

首先是取参数部分,这部分代码如下所示。模型一共输出了10*nr_mix个参数,即输出了nr_mix组参数,每组有10个参数。如前所述,第一个参数是选择该分布的未经过softmax的概率logit_probs,之后的6个参数是三个通道的均值及三个通道的标准差取log,最后3个参数是描述通道间依赖关系的$\alpha, \beta, \gamma$。不用去认真阅读这段代码,只需要知道这些代码可以把数据取出来即可。

1
2
3
4
5
6
7
8
xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100)
nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics
logit_probs = l[:,:,:,:nr_mix]
l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3])
means = l[:,:,:,:,:nr_mix]
log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.)
coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])

之后是求均值部分。在第一行,作者用了一种曲折的方式实现了repeat操作,把x在最后一维重复了nr_mix次,方便后续处理。在第二第三行,作者根据论文里的公式,调整了G通道和B通道的均值。在最后第四行,作者把所有均值张量拼到了一起。

1
2
3
4
x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels
m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix])
m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix])
means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3],3)

再来是求离散概率部分。作者根据论文里的公式,算出了当前离散分布的积分上限和积分下限(通过从累计分布密度函数里取值),再做差,得到了离散分布的概率。由于最终的概率值要求log,作者没有按照公式的顺序先算累计分布概率函数的值,再取log,而是把所有计算放到一起并化简。这样代码虽然难读了一点,但减少了不必要的计算,也减少了精度损失。

1
2
3
4
5
6
7
8
9
 centered_x = x - means
inv_stdv = tf.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1./255.)
cdf_plus = tf.nn.sigmoid(plus_in)
min_in = inv_stdv * (centered_x - 1./255.)
cdf_min = tf.nn.sigmoid(min_in)
log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling)
log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling)
cdf_delta = cdf_plus - cdf_min # probability for all other cases

作者还算了积分区间中心的概率,以处理某些边界情况。实际上这个值没有在代码中使用。

1
2
3
mid_in = inv_stdv * centered_x
log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in)
# log probability in the center of the bin, to be used in extreme cases (not actually used in our code)

光做差还不够。为了处理颜色值在0和255的边界情况,作者还给代码加入了一些边界上的特判,才得到了最终的概率log_probs
1
2
3
4
5
log_probs = tf.where(x < -0.999, log_cdf_plus, 
tf.where(x > 0.999, log_one_minus_cdf_min,
tf.where(cdf_delta > 1e-5,
tf.log(tf.maximum(cdf_delta, 1e-12)),
log_pdf_mid - np.log(127.5))))

最后是loss求和部分。除了要把离散概率的对数求和外,还要加上选择这个分布的概率的对数。log_prob_from_logits就是做一个softmax再求一个log。算上了选择分布的概率后,再对loss求一次和,就得到了最终的loss。

1
2
3
4
5
log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs)
if sum_all:
return -tf.reduce_sum(log_sum_exp(log_probs))
else:
return -tf.reduce_sum(log_sum_exp(log_probs),[1,2])

至此,我们就看完了训练部分的关键代码。我们再来看一看采样部分最关键的代码,怎么从logisitc分布里采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def sample_from_discretized_mix_logistic(l,nr_mix):
ls = int_shape(l)
xs = ls[:-1] + [3]
# unpack parameters
logit_probs = l[:, :, :, :nr_mix]
l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3])
# sample mixture indicator from softmax
sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5))), 3), depth=nr_mix, dtype=tf.float32)
sel = tf.reshape(sel, xs[:-1] + [1,nr_mix])
# select logistic parameters
means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4)
log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.)
coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4)
# sample from logistic & clip to interval
# we don't actually round to the nearest 8bit value when sampling
u = tf.random_uniform(means.get_shape(), minval=1e-5, maxval=1. - 1e-5)
x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u))
x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.)
x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.)
x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.)
return tf.concat([tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])],3)

一开始,还是和刚刚的求loss一样,作者把参数从网络输出l里拆出来。logit_probs是选择某分布的未经softmax的概率,其余的参数是均值、标准差、通道间依赖参数。

1
2
3
4
5
6
def sample_from_discretized_mix_logistic(l,nr_mix):
ls = int_shape(l)
xs = ls[:-1] + [3]
# unpack parameters
logit_probs = l[:, :, :, :nr_mix]
l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3])

之后,作者对logit_probs做了一个softmax,得到选择各分布的概率。之后,作者根据这个概率分布采样,从nr_mix个logistic分布里选了一个做为这次生成使用的分布。作者没有使用下标来选择数据,而是把选中的序号编码成one-hot向量sel,通过乘one-hot向量来实现从某数据组里取数。

1
2
sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5))), 3), depth=nr_mix, dtype=tf.float32)
sel = tf.reshape(sel, xs[:-1] + [1,nr_mix])

接着,作者根据sel,取出nr_mix个logistic分布中某一个分布的均值、标准差、依赖系数。

1
2
3
4
# select logistic parameters
means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4)
log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.)
coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4)

再然后,作者用下面两行代码完成了从logistic分布的采样。从一个连续概率分布里采样是一个基础的数学问题。其做法是先求概率分布的累计分布函数。由于累计分布函数可以把自变量一一映射到0~1之间的概率,我们就得到了一个0~1之间的数到自变量的映射,即累积分布函数的反函数。通过对0~1均匀采样,再套入累积分布函数的反函数,就完成了采样。下面第二行计算其实就是在算logisitc分布的累积分布函数的反函数的一个值。

1
2
u = tf.random_uniform(means.get_shape(), minval=1e-5, maxval=1. - 1e-5)
x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u))

只从分布里采样还不够,我们还得算上依赖系数。把依赖系数的贡献算完后,整个采样就结束了,我们得到了RGB三个颜色值。

1
2
3
4
x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.)
x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.)
x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.)
return tf.concat([tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])],3)

至此,PixelCNN++中最具有学习价值的代码就看完了。让我再次总结一下PixelCNN++中的重要代码,并介绍一下学习它们需要什么前置知识。

PixelCNN++中第一个比较重要的地方是掩码卷积的实现。它没有真的使用到掩码,而是使用了卷积中心在卷积核下方和右下角的卷积来等价实现。要读懂这些代码,你需要先看懂PixelCNN和Gated PixelCNN里面对于掩码卷积的定义,知道PixelCNN++为什么要做两种卷积。之后,你还需要对卷积操作有一点基础的认识,知道卷积操作的填充方式其实是在改变卷积中心在卷积核中的位置。你不需要懂太多TensorFlow的知识,毕竟卷积的API就那么几个参数,每个框架都差不多。

PixelCNN++的另一个比较重要的地方是logistic分布的离散概率计算与采样。为了学懂这些,你需要一点比较基础的统计学知识,知道概率密度函数与累积分布函数的关系,知道怎么用计算机从一个连续分布里采样。之后,你要读懂PixelCNN++是怎么用logistic分布对离散概率建模的,知道logistic分布的累计分布函数就是sigmoid函数。懂了这些,你看代码就不会有太多问题,代码基本上就是对论文内容的翻译。反倒是如果读论文没读懂,可以去看代码里的实现细节。

最近,我需要在Python里使用PatchMatch算法(一种算两张图片逐像素匹配关系的算法)。我去网上搜了一份实现,跑了下测试程序,发现它跑边长300像素的图片都要花三分钟。这个速度实在太慢了。我想起以前瞟到过一篇介绍Python加速的文章,里面提到过Numba这个库。于是,我现学现用,最终成功用Numba让原来要跑180秒的程序在0.6秒左右跑完。可见,Numba学起来是很快的。在这篇文章中,我将以这个Python版PatchMatch项目为例,介绍如何快速从零上手Numba,以大幅加速Python科学计算程序。这篇文章不会涉及PatchMatch算法的原理,只要你写过Python,就能读懂本文。

缘起:一份缓慢的 PatchMatch 实现

PatchMatch是Adobe提出的一种快速计算两张图片逐像素匹配关系的算法。也就是说,输入两张类似的图片A和B(比如视频里的连续两帧),算法能输出图片A中的每个像素对应B中的哪个像素(可能会出现多对一的情况)。为了快速验证算法的效果,我们可以输入图片A和B,用算法获取A到B的匹配关系,再根据匹配关系从B中取像素重建A。如果重建出来的图片和原来的图片A看上去差不多,那算法的效果就很不错。下图是一份PatchMatch测试程序的输出。

我在GitHub上找到了一份简明实用的Python版PatchMatch实现,得到了上面的输出结果。结果是挺不错,但哪怕是跑326x244这么小的图片,都要花约180秒才能跑完。

我懒得从头学一遍PatchMatch,决定直接上手优化代码。代码不长,其函数调用关系能很快理清。

代码入口函数是NNS()。它先是调用了initialization(),再循环itr次,每次遍历所有像素,对每个像素调用propagation()random_search()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def NNS(img, ref, p_size, itr):
A_h = np.size(img, 0)
A_w = np.size(img, 1)
f, dist, img_padding = initialization(img, ref, p_size)
for itr in range(1, itr + 1):
if itr % 2 == 0:
for i in range(A_h - 1, -1, -1):
for j in range(A_w - 1, -1, -1):
a = np.array([i, j])
propagation(f, a, dist, img_padding, ref, p_size, False)
random_search(f, a, dist, img_padding, ref, p_size)
else:
for i in range(A_h):
for j in range(A_w):
a = np.array([i, j])
propagation(f, a, dist, img_padding, ref, p_size, True)
random_search(f, a, dist, img_padding, ref, p_size)
return f

initialization()先是定义了一些变量,之后对所有像素调用cal_distance()

1
2
3
4
5
6
7
8
9
def initialization(A, B, p_size):
A_h = np.size(A, 0)
A_w = np.size(A, 1)
...
for i in range(A_h):
for j in range(A_w):
...
dist[i, j] = cal_distance(a, b, A_padding, B, p_size)
return f, dist, A_padding

propagation()主要调用了一次cal_distance()

1
2
3
4
5
6
7
8
9
10
11
12
def propagation(f, a, dist, A_padding, B, p_size, is_odd):
...
if is_odd:
if idx == 1:
...
dist[x, y] = cal_distance(a, f[x, y], A_padding, B, p_size)
if idx == 2:
...
dist[x, y] = cal_distance(a, f[x, y], A_padding, B, p_size)
else:
# 和 is_odd 时类似
...

random_search()则主要是在一个while循环里反复调用cal_distance()

1
2
3
4
5
6
def random_search(f, a, dist, A_padding, B, p_size, alpha=0.5):
...
while search_h > 1 and search_w > 1:
...
d = cal_distance(a, b, A_padding, B, p_size)
...

最后来看被调用最多的cal_distance()。这个函数用于计算图片A,B之间的某个距离。也别管这个距离是什么意思,总之是这一个有点耗时的计算。

1
2
3
4
5
6
7
8
9
def cal_distance(a, b, A_padding, B, p_size):
p = p_size // 2
patch_a = A_padding[a[0]:a[0] + p_size, a[1]:a[1] + p_size, :]
patch_b = B[b[0] - p:b[0] + p + 1, b[1] - p:b[1] + p + 1, :]
temp = patch_b - patch_a
num = np.sum(1 - np.int32(np.isnan(temp)))
dist = np.sum(np.square(np.nan_to_num(temp))) / num
return dist

至此,这份程序就差不多看完了。可以发现,代码大部分时候都在遍历像素,且遍历每个像素时多次调用cal_distance()函数。而我们知道,拿Python本身做计算是很慢的,尤其是在一个很长的循环里反复计算。这份代码性能较低,正是因为代码在遍历每个像素时做了大量计算。

我以前看过一篇文章,说Numba库能够加速Python科学计算程序,尤其是加速带有大量循环的程序。于是,我去学习了一下Numba的基础用法。

Numba 基础

Numba的官方文档提供了非常友好的入门教程。我们来大致把教程过一下。

Numba可以用pip一键安装。

1
pip install numba

Numba尤其擅长加速循环以及和NumPy相关的计算。使用@jit(nopython=True)(或@njit)装饰一个函数后,我们可以在这个函数里随便写循环,随便用NumPy计算,就像在用C语言一样。经Numba优化后,这个函数会跑得飞快。以下是官方给出的入门示例程序。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from numba import jit, njit
import numpy as np

x = np.arange(100).reshape(10, 10)


@jit(nopython=True) # 设置 "nopython" 模式以获取最优性能,等价于 @njit
def go_fast(a): # 初次调用时函数将被编译成机器码
trace = 0.0
for i in range(a.shape[0]): # Numba 喜欢循环
trace += np.tanh(a[i, i]) # Numba 喜欢 NumPy 函数
return a + trace # Numba 喜欢 NumPy 广播


print(go_fast(x))

Numba是怎么完成加速的呢?从装饰器名jit(JIT,Just-In-Time Compiler的简称)中,我们能猜出,Numba使用了即时编译技术,把函数直接翻译成了机器码,而没有像普通Python程序一样解释执行。Numba有两种编译模式,最常见的模式是令参数nopython=True,在编译中完全不用Python解释器。这种模式下,函数能以最优性能翻译成机器码。

修改上面的代码,我们可以测试该函数的速度。注意,由于采用了即时编译,函数在初次调用时会被编译。如果只要计算函数在编译后的运行时间,应该从第二次调用后开始计时。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from numba import jit
import numpy as np
import time

x = np.arange(100).reshape(10, 10)


@jit(nopython=True)
def go_fast(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace


# 不要汇报这个速度,因为编译时间也被算进去了
start = time.perf_counter()
go_fast(x)
end = time.perf_counter()
print("Elapsed (with compilation) = {}s".format((end - start)))

# 现在函数已经被编译了,用缓存好的函数重新计时
start = time.perf_counter()
go_fast(x)
end = time.perf_counter()
print("Elapsed (after compilation) = {}s".format((end - start)))

我们可以得到类似于下面的输出:

1
2
Elapsed (with compilation) = 1.0579542s
Elapsed (after compilation) = 1.7699999999898353e-05s

我们可以尝试一下不用Numba,直接用Python循环。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import numpy as np
import time

x = np.arange(100).reshape(10, 10)


def go_slowly(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace


start = time.perf_counter()
go_slowly(x)
end = time.perf_counter()
print("Elapsed (without Numba) = {}s".format((end - start)))

这个速度(4e-4)比用Numba慢了一个数量级。

1
Elapsed (without Numba) = 0.00046979999999985367s

也就是说,我们只要在普通的Python计算函数上加一个@jit(nopython=True)(或@njit),其他什么都不用做,就可以加速代码了。让我们来用它改进一下之前的PatchMatch程序。

用Numba计时编译加速PatchMatch

让我们开始做PatchMatch的性能调优。首先,根据性能优化的一般做法,我们要得知每一行函数调用的运行时间,找到性能瓶颈,从瓶颈处开始优化。我们可以用line_profiler来分析每一行代码的运行时间。用pip即可安装这个库。

1
pip install line_profiler

把主函数修改一下,在调用算法入口函数时拿LineProfiler封装一下,再用lp.add_function添加想监控的函数,即可开始性能分析。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
if __name__ == "__main__":
img = np.array(Image.open("./cup_a.jpg"))
ref = np.array(Image.open("./cup_b.jpg"))
p_size = 3
itr = 5

# start = time.time()
# f = NNS(img, ref, p_size, itr)
# end = time.time()
# print(end - start)

# reconstruction(f, img, ref)

lp = LineProfiler()
lp_wrapper = lp(NNS)
lp.add_function(propagation)
lp.add_function(random_search)
f = lp_wrapper(img, ref, p_size, itr)
lp.print_stats()

性能分析结果会显示每一行代码的运行时间及占用时间百分比。从结果中可以看出,在入口函数NNS()中,random_search()最为耗时。这是符合预期的,因为random_search()里还有一层while循环。

现在,我们应该着重优化random_search()的性能。我们继续查看一下random_search()的性能分析结果。

结果显示,绝大多数时间都消耗在了while循环里。也和我们之前分析得一样,cal_distance()是耗时最多的一行。除了random_search()外,其他几个函数也多次调用了cal_distance()。因此,我们目前代码优化的目标就定格在了cal_distance()身上。

刚刚学完了Numba,这不正好可以用上了吗?我们可以尝试直接给cal_distance()加一个@njit装饰器。

1
2
3
4
5
6
7
8
9
@njit
def cal_distance(a, b, A_padding, B, p_size):
p = p_size // 2
patch_a = A_padding[a[0]:a[0] + p_size, a[1]:a[1] + p_size, :]
patch_b = B[b[0] - p:b[0] + p + 1, b[1] - p:b[1] + p + 1, :]
temp = patch_b - patch_a
num = np.sum(1 - np.int32(np.isnan(temp)))
dist = np.sum(np.square(np.nan_to_num(temp))) / num
return dist

修改完代码后,再次运行程序。这次,程序报了一大堆错误。大致是说,在某一行碰到了Numba识别不了的函数。应该把np.int32()的强制类型转换改成.astype(np.int32)

1
2
num = np.sum(1 - np.int32(np.isnan(temp)))
^

改完之后,如果Numba版本较老,还会碰到新的报错:

1
Use of unsupported NumPy function 'numpy.nan_to_num' or unsupported use of the function.

报错显示,numpy.nan_to_num函数没有得到支持。再次翻阅Numba文档,可以发现,Numba并不支持所有NumPy函数。Numba对NumPy的支持情况可以在文档里查询(需要把文档切换到你当前Numba的版本)。

总之,cal_distance()这个函数不改不行了。得认真阅读一下这个函数的原理。原来,cal_distance(a, b, A_padding, B, p_size)函数是算图像A_padding和图像B中某一个像素块的均方误差的平均值,其中,像素块的边长为p_size,像素块在A_padding的坐标由a表示,在B中的坐标由b表示。

1
2
3
4
5
6
7
8
9
10
11
12
13
def cal_distance(a, b, A_padding, B, p_size):
p = p_size // 2
# 根据坐标a和边长p从A_padding里取像素块
patch_a = A_padding[a[0]:a[0] + p_size, a[1]:a[1] + p_size, :]
# 根据坐标b和边长p从A_padding里取像素块
patch_b = B[b[0] - p:b[0] + p + 1, b[1] - p:b[1] + p + 1, :]
# 求差
temp = patch_b - patch_a
# 根据非nan像素数量算有效像素数量
num = np.sum(1 - np.isnan(temp).astype(np.int32))
# 排除nan,求差的平方和,再除以有效像素数量
dist = np.sum(np.square(np.nan_to_num(temp))) / num
return dist

代码里还有一些奇怪的有关nan的运算:如果像素块里某处有nan,就说明此处像素无效,不应该参与均方误差的运算。为什么图像里会有nan呢?我们得阅读代码的其他部分。

nan是在初始化函数initialization()里加入的。A_padding原来是图像A在周围填了一圈nan的结果。我们大致能猜测出作者填充nan的原因:从A中取像素块时,若像素块在边缘,则有一些像素就不应该被计算了。拿条件语句判断这些无效像素比较麻烦,作者选择干脆在图像A周围填一圈nan,保证每次取像素块时不用判断无效像素。等算误差的时候再判断根据nan排除无效像素。

1
2
3
4
5
6
def initialization(A, B, p_size):
A_h = np.size(A, 0)
A_w = np.size(A, 1)
p = p_size // 2
A_padding = np.ones([A_h + p * 2, A_w + p * 2, 3]) * np.nan
A_padding[p:A_h + p, p:A_w + p, :] = A

使用nan填充,既耗时,兼容性又不好。为了尽可能加速cal_distance(),我把填充改成了edge填充,即让填充值等于边界值,并取消了无效像素的判断。也就是说,若像素块取到了图像外的像素,则认为这个像素和边界处的像素一样。这个假设是很合理的,这种修改几乎不会损耗算法的效果。

除此之外,为了进一步减少cal_distance()中的计算,我把要用到的变量都提前在外面算好再传进来。由于现在不需要考虑无效像素的数量,可以直接对误差求和,不用再算平均值,少做一次除法。还有,现在用@njit装饰了函数,可以放心大胆地在循环里做计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def initialization(A, B, p_size):
A_h = np.size(A, 0)
A_w = np.size(A, 1)
p = p_size // 2
A_padding = np.pad(A, ((p, p), (p, p), (0, 0)), mode='edge')

# Numba 循环写法
@njit
def cal_distance(x, y, x2, y2, A_padding, B, p):
sum = 0
for i in range(p + p + 1):
for j in range(p + p + 1):
for k in range(3):
a = float(A_padding[x + i, y + j, k])
bb = B[x2 - p + i, y2 - p + j, k]
sum += (a - bb)**2
return sum

当然,用NumPy实现cal_distance也是可以的。

1
2
3
4
5
# NumPy 等价写法,加上@njit更快
def cal_distance(x, y, x2, y2, A_padding, B, p):
patch_a = A_padding[x:x + p, y:y + p, :].astype(np.float32)
patch_b = B[x2 - p:x2 + p + 1, y2 - p:y2 + p + 1, :]
return np.sum((patch_a - patch_b)**2)

经测试,把nan的判断全部去掉后,使用NumPy版的cal_distance(),程序的运行时间降到了60秒。给NumPy版的cal_distance()加上@njit,运行时间进一步降低到了33秒。而如果使用带@njit装饰的循环写法,则运行时间也差不多是33秒,甚至还略快一些。这些测试结果印证了Numba的特性:

  1. Numba可以加速和NumPy张量相关的计算
  2. 在Numba中使用循环不会降低运行速度

成功用@njit优化完了代码中最深层的cal_distance(),我们会想,是不是所有函数都可以用同样方法加速?我们可以来做个实验,给最外层的入口函数NNS()加上@njit

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
@njit
def NNS(img, ref, p_size, itr):
A_h = np.size(img, 0)
A_w = np.size(img, 1)
f, dist, img_padding = initialization(img, ref, p_size)
for itr in range(1, itr + 1):
if itr % 2 == 0:
for i in range(A_h - 1, -1, -1):
for j in range(A_w - 1, -1, -1):
a = np.array([i, j])
propagation(f, a, dist, img_padding, ref, p_size, False)
random_search(f, a, dist, img_padding, ref, p_size)
else:
for i in range(A_h):
for j in range(A_w):
a = np.array([i, j])
propagation(f, a, dist, img_padding, ref, p_size, True)
random_search(f, a, dist, img_padding, ref, p_size)
return f

运行程序,会得到类似于下面的报错:

1
Untyped global name 'initialization': Cannot determine Numba type of <class 'function'>

把报错放网上一搜,原来,@njit的自定义函数只能调用加@njit的自定义函数。也就是说,在上面这份代码里,我们虽然用@njit装饰了NNS(),但我们自己定义的initialization(), propagation(),random_search()全部都没有用@njit装饰,因此NNS()的编译会出错。看来,我们得自底向上一步一步加上@njit了。

先来尝试修改一下initialization()。很可惜,直接加上@njit会报错。

1
2
3
4
5
6
7
8
9
10
11
12
13
@njit
def initialization(A, B, p_size):
A_h = np.size(A, 0)
A_w = np.size(A, 1)
B_h = np.size(B, 0)
B_w = np.size(B, 1)
p = p_size // 2
random_B_r = np.random.randint(p, B_h - p, [A_h, A_w])
random_B_c = np.random.randint(p, B_w - p, [A_h, A_w])
...

#报错
# Use of unsupported NumPy function 'numpy.size' or unsupported use of the function.

报错是说有不支持的NumPy函数numpy.size。实际上,不仅是numpy.size,Numba也不支持有三个参数的np.random.randint。为了解决此问题,和刚刚对numpy.nan_to_num的处理一样,最好是能用其他等价写法来代替不支持的函数。如果不行的话,则应该把不支持的运算和支持的运算分离开,只加速支持的那一部分。对于initialization(),我采用了第二种解决方法,把函数中耗时的循环拆开来单独用@njit装饰,其余有不支持的NumPy函数的部分就不用Numba优化了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@njit
def initialization_loop(A_padding, B, f, dist, A_h, A_w, random_B_r,
random_B_c, p):
for i in range(A_h):
for j in range(A_w):
x, y = random_B_r[i, j], random_B_c[i, j]
f[i, j, 0] = x
f[i, j, 1] = y
dist[i, j] = cal_distance(i, j, x, y, A_padding, B, p)


def initialization(A, B, A_h, A_w, B_h, B_w, p_size):
p = p_size // 2
random_B_r = np.random.randint(p, B_h - p, [A_h, A_w])
random_B_c = np.random.randint(p, B_w - p, [A_h, A_w])
A_padding = np.pad(A, ((p, p), (p, p), (0, 0)), mode='edge')
f = np.zeros([A_h, A_w, 2], dtype=np.int32)
dist = np.zeros([A_h, A_w])
initialization_loop(A_padding, B, f, dist, A_h, A_w, random_B_r,
random_B_c, p)
return f, dist, A_padding

另外的两个函数propagation()random_search()只会碰到取形状函数numpy.size的问题。这个问题很好解决,只要把numpy.size挪到函数调用外即可。

initialization_loop()propagation()random_search()都加上@njit后,程序的运行时间从33秒猛地降到了3秒左右。可以说,只用加@njit的方法的话,程序已经没有优化空间了。

用Numba提前编译加速PatchMatch

又看了看PatchMatch的源码,我发现,PatchMatch算法会先为每个像素随机生成一个匹配关系。然后,算法会迭代更新匹配关系。迭代得越久,匹配关系越准。而我之后要用PatchMatch处理一段视频,算所有帧对第1帧的匹配关系。那么,对于视频这种连续的图像序列,我能不能让第3帧初始化匹配关系时复用第2帧的匹配结果,第4帧复用第3帧的匹配关系,以此类推,以减少迭代次数呢?

说干就干。我准备先测试一下减少迭代次数后代码运行时间能缩短多少。迭代次数itr是在main函数里指定的,作者默认的数值是5。我把它改成1测试了一下。

1
2
3
4
5
6
7
8
9
10
11
if __name__ == "__main__":
img = np.array(Image.open("./cup_a.jpg"))
ref = np.array(Image.open("./cup_b.jpg"))
p = 3
itr = 1

start = time.time()
f = NNS(img, ref, p, itr)
end = time.time()
print(end - start)
reconstruction(f, img, ref)

结果,原来要花3秒的程序还是要花接近3秒,时间缩短得非常不明显。这不应该啊,理论上程序的运行时间应该大致和itr成正比啊。

测试了半天,我突然想起Numba文档里讲过,@njit是即时编译,函数的编译会在初次调用函数时完成。我每次运行程序时,大部分时间都花在了编译上,因此整个程序的运行时间几乎不由迭代次数决定。

我之后要反复运行PatchMatch程序,而不是通过运行一次程序来处理大批数据。即时编译的代价我是接受不了的。于是,我去文档里找到了Numba提前编译(AOT,ahead of time)的使用方法。

Numba AOT可以把Python函数编译进一个模块文件中。想在其他地方调用被编译的函数时,只需要import 模块名即可 。

官方给出的Numba AOT示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from numba.pycc import CC

cc = CC('my_module')

@cc.export('multf', 'f8(f8, f8)')
@cc.export('multi', 'i4(i4, i4)')
def mult(a, b):
return a * b

@cc.export('square', 'f8(f8)')
def square(a):
return a ** 2

if __name__ == "__main__":
cc.compile()

首先,程序要用一个模块名实例化一个CC。该模块名是未来我们import时用到的名称。之后,对于想编译的函数,我们要用@cc.export装饰它。@cc.export的第一个参数是调用时的函数名(原来的函数名会被舍弃),第二个参数用于指定函数返回值和参数的类型。做完所有这些准备后,使用cc.compile()即可完成编译。

运行该程序,会得到一个模块文件。根据平台的不同,该模块文件名可能是my_module.somy_module.pydmy_module.cpython-34m.so。不管文件名是什么,只要是在同一个文件夹下,我们就可以用下面的Python命令调用这个模块文件。

1
2
3
4
5
>>> import my_module
>>> my_module.multi(3, 4)
12
>>> my_module.square(1.414)
1.9993959999999997

用Numba做即时编译时,函数的返回值类型和参数类型可填可不填。而Numba提前编译中,必须要填入函数的返回值类型和参数类型。这让编写Numba提前编译的工作量大了不少,已经不像是在写Python,而是在写C了。

还有一点值得注意。和使用即时编译时一样,自定义的函数在调用其他自定义函数时,必须要加上@njit。所以,会出现一个函数即有@njit,又有@cc.export的情况。

学习使用Numba提前编译时,最主要是要学习Numba是怎么用字符串代表参数类型的。比如,i4是32位整型,u1是8位无符号整型,u1[:, :, :]是三维8位无符号整型,void是无返回值。这些表示可以在官方文档里找到。

以我写的Numba AOT PatchMatch的编译代码为例,我们可以看一看参数类型和返回值类型是怎么描述的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
from numba import njit
from numba.pycc import CC

cc = CC('patch_match_module')


@njit
@cc.export('cal_distance', 'f4(i4, i4, i4, i4, u1[:, :, :], u1[:, :, :], i4)')
def cal_distance(x, y, x2, y2, A_padding, B, p):
...


@njit
@cc.export(
'initialization_loop',
'void(u1[:, :, :], u1[:, :, :], i4[:, :, :], f4[:, :], i4, i4, i4[:, :], i4[:, :], i4)'
)
def initialization_loop(A_padding, B, f, dist, A_h, A_w, random_B_r,
random_B_c, p):
...


@njit
@cc.export(
'propagation',
'void(i4[:, :, :], i4, i4, i4, i4, f4[:, :], u1[:, :, :], u1[:, :, :], i4, b1)'
)
def propagation(f, x, y, A_h, A_w, dist, A_padding, B, p_size, is_odd):
...


@njit
@cc.export(
'random_search',
'void(i4[:, :, :], i4, i4, i4, i4, f4[:, :], u1[:, :, :], u1[:, :, :], i4, f4)'
)
def random_search(f, x, y, B_h, B_w, dist, A_padding, B, p_size, alpha=0.5):
...


if __name__ == "__main__":
cc.compile()

运行该程序后,在我的电脑上得到了名为patch_match_module.cp37-win_amd64.pyd的模块文件。可以在其他代码里通过import patch_match_module调用编译好的函数了,比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import patch_match_module

def NNS(img, ref, p_size, itr):
A_h = np.size(img, 0)
A_w = np.size(img, 1)
B_h = np.size(ref, 0)
B_w = np.size(ref, 1)
f, dist, img_padding = initialization(img, ref, A_h, A_w, B_h, B_w, p_size)
for itr in range(1, itr + 1):
if itr % 2 == 0:
for i in range(A_h - 1, -1, -1):
for j in range(A_w - 1, -1, -1):
patch_match_module.propagation(f, i, j, A_h, A_w, dist,
img_padding, ref, p_size,
False)
patch_match_module.random_search(f, i, j, B_h, B_w, dist,
img_padding, ref, p_size,
0.5)
else:
for i in range(A_h):
for j in range(A_w):
patch_match_module.propagation(f, i, j, A_h, A_w, dist,
img_padding, ref, p_size,
True)
patch_match_module.random_search(f, i, j, B_h, B_w, dist,
img_padding, ref, p_size,
0.5)
return f

加上最后这步提前编译后,PatchMatch的运行时间从3秒降低到了0.6秒多。程序从最开始的180秒降到了0.6秒,几乎快了300倍。而且,如果是处理视频,还可以通过复用前一帧信息来减少迭代次数,进一步缩短每一帧的平均处理时间。能加速这么多,并不是我太强,而是Python实在太慢了。纯Python就不应该用来写科学计算程序。

总结

通过阅读这篇文章,相信大家能根据我这次Python PatchMatch性能优化经历,在不阅读Numba文档的前提下自然而然地学会Numba的用法。我把文章中提到的和Numba性能优化有关的知识点按使用顺序总结一下。

  1. 在面向应用的程序中,不要用Python写科学计算程序。哪怕要写,也要尽可能避免在循环中使用大量计算,而是去调用各个库的向量化计算。
  2. 直接在想优化的函数前加@njit装饰。在待优化函数里使用循环、NumPy函数都是很欢迎的。
  3. 如果碰到了Numba不支持的函数,可以通过两种方式解决:1)用等价的Numba支持的函数代替;2)把不支持和支持的部分分离,只加速支持的部分。
  4. 一个带@njit函数在调用另一个自定义的函数时,那个函数也得加上@njit。因此,应该自底向上地实现Numba即时编译函数。
  5. 如果你接受不了计时编译的编译时间,可以使用提前编译技术。使用提前编译时,主要的工作是给参数和返回值标上正确的类型。

Numba确实很容易上手,只要会加@njit,剩下碰到了什么问题去搜索一下就行。Numba的官方文档很详细,想深入学习的话直接看文档就行了。

本项目的代码仓库为:https://github.com/SingleZombie/Fast-Python-PatchMatch 。在原作者仓库的基础上,我添加了PatchMatch_numba_jit.pyPatchMatch_numba_compile.pyPatchMatch_numba_aot.py这三个文件。它们分别表示即时编译运行程序、提前编译编译程序、提前编译运行程序。

经历

学校申请在2023年1月31日截止,托福成绩要求100分。

我在两年前读本科时考过托福,阅读/听力/口语/写作的分数分别为26/25/21/22,共94分。在这两年里,我的英语听说读写水平均有提升。大致估计一下,我应该可以轻轻松松地多考6分,拿到100+的成绩。

带着这样的想法,我2022年11月才开始准备托福。准备之前,我先订下了复习的策略。由于我不知道现在的考试水平,只能以两年前的成绩为基础,思考各部分的提分策略。阅读、听力都是客观题,很容易自学,各提个2分非常简单。据很多人反映,口语很难靠自学提升,那我只要保持之前的21分即可。这样,要让总分超过100,这次考试最重要的就是写作了。对于看重英语读写教育的中国学生来说,英语写作的基础在高中时就学得差不多了。而针对托福考试的写作技巧,则可以快速学会。因此,我应该花时间集中攻关托福写作技巧。综上,我的复习策略总结为:稳步提高阅读与听力,保持口语,攻关写作。

11月,我一边正常工作,一边断断续续地做了阅读和听力的TPO真题,再次熟悉了这两部分的应试方法。同时,我也做了一两套口语题,熟悉了题型,保证听完题目后能张口答题。最后一周,我开始集中学习写作技巧。网上的托福「教程」全是打广告的垃圾信息。我从互联网垃圾堆里翻出了一篇论坛上的经验分享贴,下载了一本叫做《十天突破新托福写作》的书。这本书虽然成书于2010年,但其内容十分全面,毫不过时。我认真学习了一些能够短期内掌握的写作技巧。

慢吞吞地准备了一个月,差不多了。我把第一场托福家考的时间订在12月5日早上9:50。我不方便把考试时间订在我最清醒的晚上,又考虑到我有以前线下考试的经验,才把时间安排在了早上。另外,12月5日是星期一,在此之前我能有一整个周末的时间来做最后的准备。如果一切顺利的话,我可以在不影响正常工作的前提下于周一下午结束托福考试。

第一场考试即将开始。我知道自己准备得不够好,却也因此做好了多考几次的打算,只把这次考试当成了一次分数测试,没有什么心理压力。托福家考的考官会对用作考场的个人房间进行检查。和考官的几句英文对话,也恰到好处地缓解了我的紧张,令我的大脑迅速进入了英语环境。克服了刚开始动脑时打的几个哈欠后,我顺顺利利地完成了考试。

和普通的线下托福机考一样,家考结束后,只有客观题的阅读和听力会立刻出分。我迫不及待地点下查分按钮后,看到的是一个令我傻眼的分数:阅读27,听力19。

我曾经为查分结果做出多种假设,以决定下一次考试的时间。如果阅读和听力的分数特别好,我就等所有分数出了再看要不要继续考试;如果阅读和听力的分数较差,我就再学一阵,学得差不多了就考。无论如何,看到分数后,我的心情要么是欣喜,要么是失落。可是,看到这只有19分的听力,我先是撇嘴笑了笑,随后皱起了眉头。一阵不服气的火焰从我心中燃起。种种迹象表明,我的英语听力水平没有那么差。在这次听力考试时,我每篇材料也听得津津有味。怎么能考出这样的分数出来呢?这已经不是我有问题了,这绝对是因为托福出题方ETS出题不力,不能让题目反映应试者水平。空口无凭,下一次,我会用我的无懈可击的表现来证明托福听力的不合理。

我也不管这次的分数怎么样了,准备在最快的时间里开始下一次考试。根据最新的规定,3天之后就能考下一场托福。但是,在临考一周内报名要交额外的费用。于是,我决定在不缴纳晚考费的第一天,也就是下周一的同一时间,再战托福。在这一周,我会把我的托福听力水平提升到最强的境界。考完了,如果听力考出高分,我还要写一篇文章来总结托福听力经验,把托福听力狠狠地踩在地上。接下来一周的剧本,已经准备就绪。

自我上进并不是令人奋斗的理由。只有对自我或外界的不满而产生的复仇心理,才能催促人不断前进。我的拖延被立刻治好。第一场考试结束当晚,我就精神饱满地重新开始了托福听力的准备。本来我想把这整周都用在听力的练习上。可是,由于学习速度过快,练了三天,我就发现我的听力水平大有提升,以至于没有什么进步空间了。于是,最后几天,我还是去找了找口语和写作的状态,为第二次考试准备。

12月12日,周一,早上9:50,为了完成对托福听力的复仇,我又杀回来了。由于是第二次参加托福家考,一切安检的流程我都轻车熟路。十分钟内,考官就检查完了资料和考试环境,第二场考试开始了。

第一部分是阅读。我平时做阅读总是会错几题,尤其是最后的六选三大题。但是,我对此毫不在意。因为我知道,平时我阅读出错,完全是因为我不认真,太浪。如果换成真正的考试,我就会拿出全部的实力,认真记忆文章,根本不会错那么多题。事实也确实如此。在第二次阅读考试中,我已经适应了考场上做阅读的策略,做题做得心应手。72分钟,4篇文章,我只需要每篇文章都花18分钟即可。每篇文章我都留3分钟来解决最后的六选三大题。把所有题做对不成问题。

之后是听力。我的状态和上周差不多,略有不同的是,我记笔记的方法更好了,信心更足了。口语和综合写作的发挥也和上周差不多。最后一项任务是独立写作,其话题稍微有点难度:「政府要建拆旧建筑修新建筑,一批人要被迫搬到新的地方,你同意吗?」我半天没有找到特别通畅的写作思路,只好硬着头皮分个人、企业、政府三个方面扯了一堆反对的理由。

考试结束,又到了开奖时间。确认考试结束的画面我看都不看,狂点着下一步,迫不及待地查看起阅读和听力的分数:

阅读28,听力27。

多么令人惊喜的分数!这两部分的分数均创我个人新高,成功达成了「阅读、听力各高2分」的目标。另外,我前一周吹的牛也成真了。我确实考出了一个不错的听力成绩,有资格去写托福听力准备攻略了。再次确认分数没有看错后,我在房间里大笑不止。计划的顺利执行、复仇的完美实践、成功的如期而至、心腹大患已除的如释重负。种种喜讯,充斥着我的大脑,让我只能以大笑来释放那高悬已久的紧张心情。我一边笑,一边狂妄地叫道:「如果这个阅读和听力的分数都不能让我总分上100,我就去炸了ETS!」我觉得这次总分肯定有100分了。把两年前的口语阅读分搬过来,就已经有98分了。这两部分只要稍微提高一点,100分就有了。

两天后的早上,也就是第一场考试结束后的第九天,ETS寄来邮件,说第一场考试的成绩已出。我一看,阅读27,听力19,口语23,写作26,总分95。两年前,我的口语和写作分别是21分和22分。这样一看,我口语和写作的提升确实不小。把昨天刚考出来的阅读和听力在拼起来,总分就有104分了。相比上周的发挥,我有足足4分的退步空间。我觉得我这次考试上100分已经是板上钉钉的事情了。

我兴冲冲地和领导讲,托福已经考完了,我从这周开始可以正常干活了。用于做笔记的白板被我扔到房间一角,为了通过考试安检而清理的房间又被我弄乱,羞于让考官看到的二次元电脑桌面又被我换回。生活重回正轨,仿佛什么都没有发生过。我写了一篇托福听力准备攻略,以此为送别托福的赠礼。

放松了一周,又到了周一早上。总算,我不用再面对9:50开始的托福考试,可以睡个懒觉,再去学校。本来是轻松快乐的一天,不知怎地,我怎么都提不起劲,仿佛是背后被锁链扯住了。到晚上临睡时,不安感越来越强。我这才想起来,明天是考试后第八天。最快的情况下,明天就会出分了。在最终分数出来之前,一切还没有结束。

周二早上9点,我在睡醒后立刻点开了邮箱。果然,上次的托福分数已经出了。我连忙点进查分页面,第一眼看到了一个三位数的分数。我按捺住激动的心情,仔细一看,这个分数是104分。原来这是TOEFL MyBest Score,也就是各个小分的最优拼分结果。和我算得一样,拼分结果是104分。现在庆祝还为时尚早,我把眼睛往上挪了挪,看到了我最不想见到的结果——一个只有两位数的总分。定睛一看,上次考试的总分是99分!口语、写作都只考了22分,这导致最终的总分刚好离及格线100分差1分。

反复确认学校不接受拼分结果后,我的心凉了下来。无数的负面想法萦绕在我的脑中:「考前怎么不多准备一下口语和写作」、「为什么这么晚才考试」、「考试的时候状态好一点就好了」……。但是,现在,连多余的自责都是没有意义的。我要做的,只有尽力考好下一次托福。我也没有思考,照着上次的习惯,报考了下周二早上9:50的托福家考。一切从头开始。

这一周,我完全失去了活力。每天上午,楼下都在轰轰隆隆地装修。我晚上本来就睡得晚,白天多躺一下的权利也失去了。生活中又突然碰到了一些事,我不得不花了一两天去处理。好不容易到了周末,我才有心思去准备一下托福。可是,我一看到屏幕上的做题界面,就感到全身有无数根线在撕扯着我,让我快点结束做题。

截至周一,考试前一天,我这周几乎没有花时间准备考试,生物钟也乱得一塌糊涂。我想,最后一天了,准备考试也来不及了,就把生物钟调整一下吧。我找事情消磨着时间,希望自己不要在白天补觉了。下午,吃过午饭,我实在困得受不了,就靠在床头,心想:「就靠着休息一下吧。靠着睡睡得不舒服,睡一下应该就能醒来。」但当我醒来的时候,天已经黑了。我已经困到连靠着睡也能睡四个小时了。根据前几天的经验,我知道今天晚上睡不着了。

睡不了觉,那就用最后的时间准备一下考试吧。按我以前的复习习惯,考试前最后准备的应该是口语。然而,听完一道口语的题面,花完15秒的准备时间,当录音开始的一瞬间,我发现自己张嘴发不出声音。原来,在准备时间里,我的大脑一片空白,根本组织不出英语句子。既然如此,开口讲一些乱七八糟的话,也是浪费时间。巨大的焦虑,把开口讲英语也变成了一件困难的事。

什么事都做不下去,我只好关灯,上床,躺着,一秒一秒迎接末日的来临。在最应该睡觉的夜晚,我的大脑却无比清醒。我糊里糊涂地想了很多事情。最后也想开了,这次考试肯定是考不过了。今天是12月27日,离1月31日还早,还有机会准备。我爬起来想去取消这次考试,挽回一些考试费,却发现考试早就定下来不允许取消了。没办法,就当这次考试是调整心情吧。相比白白花了钱不去考试,还是参加一下比较好。

早上七点,天亮了。我总算困了,睡了下去。曾经慢慢吞吞令我倍感煎熬的时钟,此刻却又卯足了劲飞奔起来。没过多久,闹钟就响了,提醒我9:50时有一件非做不可的事。9:30,我艰难地起了床,准备迎接第三次考试。

我从来没有喝醉过酒,此刻,我却能深刻理解喝醉酒的痛苦。听说酒喝多了会呕吐,那么,止不住的呕吐,应该和止不住的哈欠是一样的吧。明知困得睁不开眼,却要在不断的哈欠中强打精神。这种折磨不是和明知胃里没有东西却不得不重复地呕吐一样吗?每次早上考托福时,我其实睡得都不太够。做一开始的阅读题时,精神又必须得完全集中。所以,每次考托福,我都会打一阵子哈欠,花一些时间让大脑适应。而今天,我的睡眠时间实在太短了,以至于哈欠不断,大脑久久不能集中。还是4篇阅读,72分钟,平均每篇18分钟。第一篇阅读我恰好花18分钟做完。做第二篇时,我被一道题卡住了,多花了足足6分钟。后面两篇文章我只好平均花15分钟飞快地做完。做完阅读,我就知道这次考试已经没救了。随后,我又勉强做完了听力题。到了口语考试时,我糟糕的状态就一发不可收拾了。和前一天晚上的表现一样,我在口语的准备时间里非常慌张,根本动不了脑子。等开始录音了,我才开始边想边讲,凭借本能让英语单词一个一个从嘴里蹦出来。最后,到了写作的时候,我的脑子总算清醒了。结果,独立写作来了道题面很长,题材很怪的题。我根本想不出这个话题可以从哪三个角度来讲,只好硬凑字数,第一次写了一篇只有两个主体段的文章。

考试结束,我漠然地查看客观题成绩:阅读25,听力22。「还好,至少听力比第一次的19分高。」反正这次考试肯定考不过100,我只能苦中作乐,怎么乐观怎么想了。确认考试结束后,我如释重负,往床上猛地一躺。

差不多是上周的这个时候,我确认了托福还要再考。当时,我曾用一张薄纸封住了内心的负面想法,只希望下次考试能够尽力。虽然我无时无刻不在被透过纸的针扎痛,但我一直扛到了现在。如今,已经没有任何克制的需要了。「为什么不早点准备?」、「为什么不好好准备?」、「为什么会差1分?」、「为什么这周状态那么差」、「为什么……?」、「为什么……?」、「为什么……?」、……。新账加上旧账,弯刀接着利剑。无尽的责问,让疲惫的我想躺在床上就此昏睡过去。可是,一想到没能好好睡觉也是失利的败因之一,我又立刻咬牙坐了起来。疲倦与饥饿接踵而至,我决定先去吃饭。

吃完饭,我去超市里买生活用品。没想到,几天不来,超市从白与绿的圣诞风格,变成了以红为主的春节风格。我这才发现,圣诞节已经在不经意间过去了。再这样考不过,新年也是同样的过法吧。在这一个多月,我除了准备托福和消磨时间,别的什么都没干。学习停了,工作停了,什么都不会做了。我就在不断重复着考托福,考不过,考托福,考不过。这样下去,一月底还考不过,导师可能就不要我了。后面的事我也不敢再想下去了。

吃饱喝足后,我冷静分析了一下现状。我认为我的水平没有问题,只是这周的状态太差了。第二次考试相比第一次,写作的分数低了很多。我还是应该保持其他成绩,全力提升写作。离申请截止还有四周,最多最多还有三次考试的机会。我振作起来,又报名了一周后,也就是下个周二的考试。说来奇怪,一考完,我的精神状态就正常了,当天晚上也睡得很香。过了两天,楼下的装修声也听不到了,我的生物钟算是调整回去了。我名义上还是在正常上班,工作日的时候敲了敲代码,生活的自信又找回来了。

正常准备了一周,转眼又到了考试前一天。按惯例,最后的时间我还是在练口语。上周的事情让我心有余悸,这次一定不能怠慢口语了。可是,无论我怎么练,说起话来都磕磕巴巴的,找不回第一次考试时那种飞速组织语言、毫无顾虑地发言的状态了。我花了很多时间,才练到勉勉强强能把回答说完。

1月3日,第四次考试。我虽然睡得还不是很充足,但算是回到了那种打几个哈欠就能清醒的状态了。然而,这次又出现了一点意外情况。之前的考试,我怕考官会禁止我使用纸巾,鼻子不舒服也没有拿纸来擤鼻涕。这次考试,我的鼻子特别不舒服。我拿纸快速地擤了一下鼻涕,发现考官并没有提出异议。我的鼻子好像接收到了这个信号,开始止不住地流鼻涕。结果,我隔一两分钟就要擦一次鼻涕。本来就困,呼吸还不通畅,我阅读做到一半突然就开始手脚发麻。还好我脑子一直是清醒的,除了中间略有卡壳外,大部分阅读题我都能确定地答出。做到听力时,我的鼻子总算正常了,考试开始平稳进行。口语考试时,我的状态也和考前一样,非常一般,只是恰好能够把题目答完。由于之前做过足够的练习,我做综合写作倒是十分顺利。可是,做独立写作时,又不太对劲了。题面是「乡村生活和城市生活哪个更好」。我想出了三个城市好的理由,却没太想好怎么论证,写了几个不太自然的示例。

考完,又是要提前看分。经历丰富的我面对查分已经没有了任何激动,只是平淡地一步一步点进了查分界面。这次,阅读28,听力21。阅读的发挥是符合预期的,但这个听力是怎么回事?怎么多睡了几个小时还没有只睡两小时做的听力分数高?

经历了太多太多,我已经麻木了。看到这种分数,我既没有开心,也没有难过。我只想问,接下来该怎么办?越花时间准备,听力考的分数越低。亏我还是写过托福听力攻略的人。反正花时间在准备考试上也没有提升,不如多玩一玩,那所以说我的懈怠也是合理的。托福考试嘛,其实就是看运气。包括两年前的那次在内,5次考试,我听力只有两次上了25分。我的听力水平其实很差。碰到运气好了,刚好听清了,分数就很高。多数(>50%)情况下,分数是很低的,能反映我真实水平的。写作也是,运气好,考官心情好,给了个26分。其实我水平没有那么高。我还能做什么呢?只能多考几次,撞撞运气了。

越想我就越不服气。为什么我总是习惯于责怪自己呢?就不能把失败归咎于外界吗?我不是曾经说过,托福听力有问题,不能反映考生的真实听力水平吗?今天的身体状态难道没有影响我考试吗?我没有问题,就是客观情况有问题。

想着想着,第一种声音又冒出来了:那为什么在半数以上的考试中,我的听力成绩那么差?这不就是菜吗?考试成绩就是一切。考得差了就是自己的问题。

自我斗争了半天,我决定用一种公平的方式来审判:在下次考试中,我将尽可能排除一切客观上的不利,用最佳状态来考试。如果还是考不好,那就是我英语水平有问题;如果考得好,过去的事情就一笔勾销,我可以光明正大地把考试失利甩锅给考试状态。这时,我才发现我之前有多么愚蠢。考试准备得用不用功另说,为什么我不把应试状态调整到最好呢?睡眠不好调整,但是考试时间是可以调整的啊。我可以早上晚一小时再考,然后喝一瓶红牛拉满状态。提高应试状态比努力准备考试要简单得多,却又有效得多啊!

我再次分析起了现在的学习情况。这两周我都没练过阅读,全是在考场上边考边练。在真枪实弹的练习中,我的阅读技术已经在不经意间练得炉火纯青。这次考试我在状态极差的情况下都考了28分就是证明。而下次考试我会状态拉满,岂不是随随便便就能考个满分?这样一看,阅读就不用练了。写作练了这么久,综合写作也没有提升空间了,独立写作练得也没效果了,维持在22分以上应该没问题。剩下要练的只有听力和口语。听力我曾经拿过27分,我只要练回当时的状态就行了。至于口语,从第三次考试开始,我考试时的心态就出现了问题。所以,这次口语准备的重点是做到不要紧张。只要口语考出正常水平,不比22分低,总分也够100了。一周的准备时间是足够的。我毅然订下了下周二早上10:50的托福家考。

这一周的头几天,我把所有准备时间都放在了听力上。和以前的练习方法略有不同,现在,我只在每天最清醒的时候,用最集中的精神去做题。一旦觉得困了,注意力不集中了,我就去躺一下,绝不让状态差成为借口,也免得错太多题影响心态。调整了状态后,我的练习正确率果然高了不少。

周五,听力练得差不多后,我开始练口语。正好,第三次考试的成绩也出了。由于元旦放假,成绩出得晚了几天。在我状态最差的这次考试中,我的阅读/听力/口语/写作成绩分别是25/22/17/22,总分86,创下了最低分的纪录。这只有17分的口语分数深深扎在我的心里,导致我在练口语时,越说越紧张,越练越不利索。我只能勉强安慰自己,哪怕状态再差,写作都能考22分。如果阅读和听力能接近满分,口语22分左右就可以了。现在只要把口语练到以前的水平就行,不要有太大压力,不用练得太好。就这样,最后三天我一直在练口语。

1月10日,周二,早上10点,我懒洋洋地从床上爬起来。吃完早饭,灌下大半瓶红牛,我精神饱满地坐到了电脑前。这是第五次考试了,是拼上个人名誉的一战,绝对不容有失。熟练地通过安检后,我又一次进入了阅读题的测试。一看到满屏的文字,我下意识地打了个哈欠。我立刻为这个哈欠担心了起来:「怎么回事?我还是没睡醒吗?红牛没有用吗?」不行,我不能在这里停下!我眼睛一瞪,聚精会神地读起文章来。还好,能量饮料是有用的。自此之后,我就再也没打过哈欠,状态神勇地横扫了前两篇阅读。做到第三篇阅读时,我再次被一道题卡住了。好在这次留的时间非常充足,哪怕是在第三篇文章浪费了点时间,我还是给第四篇文章留了完整的18分钟。最后,我完美利用时间,十分确定地答完了每一道题。

之后是听力。我主观感觉自己的发挥和之前几次都一样,听力材料听得很清楚,完全看题目里的信息我是不是恰好记住了。不过,客观上来看,我的发挥确实好了不少。有一道很恶心的题是问「下面这些行为是左脑还是右脑负责的」,等于一道题问了好几个小问题。由于我的笔记非常全面,轻松地把这题回答出来了。

中途休息后,继续进行口语考试。口语是我最担心的一部分。考试还没开始,我的心脏就砰砰跳个不停。第一道独立口语题一出,我就傻眼了:「有些人喜欢早上学习,晚上工作,有些人则相反。你更倾向于哪一种呢?」哪有这种题啊?我爱早上工作早上工作,爱晚上学习晚上学习,哪有什么理由啊?15秒的准备时间里,我愣是一个字的回答也没想出来。开始录音,我只好想到什么说什么,连珠炮一般地把想到的内容全说出来。说完一看,还有一半的时间,内容太少,太空洞了。我只好勉强补了一两句。这道题做得极其糟糕,我狼狈地进入了下一题。很巧,第二道题目一念,我惊讶地发现这道题是第一次还是第二次考试的原题。这整套口语题我都做过!我又是惊喜又是担忧。惊喜的是后面的材料我都记得,怎么回答我都有数了;担忧的是我的状态十分糟糕,之前答起来毫无困难的题,这次就答得这么费劲。但不管怎么样,作为一个讲武德的人,我还是当成第一次做,好好地按流程记笔记,按现场记下的笔记作答,没有在念题目的时候就开始凭记忆把回答准备好。后面这几道口语题答得都还正常。

又到了最后的写作部分。不像之前几部分的考试,写作的时间非常充裕。不管之前再怎么紧张,在写作时多数人都能冷静下来答题。我这次就很不巧,综合写作刚开始没多久,考试程序崩溃了。电脑跳转到桌面上,我突然发现我的工作应用忘了关,右下角有个图标在一闪一闪。我连忙点开应用,回了领导一句话。正当我准备关掉应用的时候,鼠标的控制权突然被考官抢了过去,重新点到了考试程序上。我瞬间惊恐起来,担心考官会认为我在作弊。不过考官也没说什么,考试很快又继续了。我就这样战战兢兢地做完了综合写作。独立写作,题目是「有人觉得大学应该收费,有些人觉得大学的学费应该由政府承担。你支持哪种观点?」这别说拿英文了,拿中文要我半小时内写一篇看上去合理的文章都很难。看到这个题目,第一想法是如果学费太高,穷人就上不起学了。可是,这只能支撑一个观点,字数一定写不够。我只好选择支持大学收费,让大学收费和高质量的教育关联起来,再分别谈高质量的教育对个人、学校、政府的好处,勉强凑够了分论点。

考试结束,第五次客观题开奖了。我感觉这次阅读做得很好,但听力还是有几道不确定的题。失败这么多次了,也习惯了,考得差就差了吧。于是,我不抱希望地点进了查分页面。接下来,我看到了两个惊人的分数:阅读30,听力28。喜悦感瞬间传遍了我的全身。要是几周之前,我肯定就立刻开心地跳起来了。可是,有了上次半场开香槟的经历,伴随着喜悦感的,是一种更沉重的恐惧感。无数的负面想法扑面而来:我的口语可是考过17分的,要是这次再考个17分怎么办?考官要是把我中途点开工作应用的行为当成作弊,取消考试成绩怎么办?就这样,看着刚考出来的高分,我压根不敢庆祝,只是呆呆地坐着。十分钟后,我突然感到全身乏力。我这才想起,现在已经是下午3点了。我已经非常饿,非常疲惫了,完全是靠能量饮料才撑到了现在。

接下来的几天,我没有再准备考试了。做其他事也提不起劲,于是我一有空就会来算分。第四次考试的分数是在周三,也就是考试后第八天出的。阅读/听力/口语/写作的分数分别是28/21/20/23,总分92。如果口语和写作的分数还是和上周一样,那总分就有101分,够了。可是,上周我口语还只是考了20分。这说明我的口语还是没找回状态。万一口语再考一次17分,分数不就不够了吗?就是这样,我用着各种方式去估计第五次考试的分数。从理性上思考,我上100分是很稳的。但是,有了第二次99分的经历,我再也不敢有乐观的想法了。

一边算着分,我还一边估计着下次出分的时间。第一次出分用了九天,第二次和第四次用了八天,第三次因为假期的原因用了十天。可见,八天或九天出分是最可能的。另外,每次通知出分的邮件都是8:05左右发的,意味着早上8点大概就可以看到分数。根据这些分析,我准备从下周三早上开始,每天早起看看有没有出分。

在周三之前,我的日常生活仿佛冻结了起来。我就像沙漏里的沙子,我的生活意义就只是单纯地见证时间的流逝。周三清晨,我一直睡得不踏实。到了早上7点,我在没有闹钟的情况下突然清醒了过来,死盯着邮箱的界面,等着8点到来。7点1分,……,8点1分,8点2分,……。等到了9点,分数还是没有出。我知道今天是不会出分了,如断了弦一般躺下并睡了过去。

又在紧张中度过了一整天后,周四清晨,我以同样的方式在7点的时候醒了过来。如我所料,8点,出分了。这次,我没有看错,精准地找到了本场考试的总分——103分。阅读/听力/口语/写作分别是30/28/22/23。如果阅读和听力的发挥稍差,和第二次考试一样,总分也有100分;如果口语和上次一样只有20,总分也有101分。可以说,我的托福是实打实地有100分的水平。哪怕有几门发挥较差,也不影响总分超过100分。反复确认了考试分数后,我总算是松了一口气。在我的「个人名誉审判」上,我总算证明了自己的实力,总算有资格把以前的失败归咎于状态。考完了,胜利了,说什么都是对的。

周六是1月21日,除夕。我有幸能快快乐乐地过一个春节。1月31日之前,我考了五次托福,总算考出了100分以上的成绩。整个考试准备过程可以说是非常狼狈。如果身边有别人说他考托福考了5次才考到100,我口头上或许会加油打气,但内心里肯定会把他嘲笑一番。但我还是想把这段不是那么光彩的经历分享出来。我认为这段经历可以帮助到很多人。

经验

虽然我在高考之后只经历过托福这一种大考,但我现在也很能理解参加其他考试(比如考研)的考生。没有老师天天监督指导,完全凭自己去找学习方向,这确实不是一件简单的事。很多情况下,有效的学习时间不多,调整学习状态(娱乐)的时间反而会更多。这些情况都是正常的,人人都是这样过来的。有害的不是不学习本身,而是自我指责的态度。负罪感会带来压力,不当的压力才会给正常学习带来阻碍。因此,准备考试时要克服的不是想要休息的心态,而是自己给自己过度施加的压力。不管怎么样,最后考完了,考好了,没人会管你过去做了什么。我复习考试的状态也非常糟糕,经常拖拖拉拉。最后考五次也勉强熬过来了。我觉得很多人的准备情况是比我好的,调整完了心态,肯定能够比我更顺利地考完考试。

自己准备考试时,最重要的是要能够自测,随时得知自己当前的水平。只有这样,你才能知道自己的复习有没有用,成绩有没有提升。另外,最好是能够把要学习的内容用进度表示,比如看多少门的内容,背多少单词。进度的完成能够鼓舞自己。最后,如果有条件,可以和准备同一门考试的同学一起学习。大家不要攀比学习进度,而只是一起分享学习的过程,聊聊天,分担一下压力。如果心态正常,能够随时检测能力是否提升,自学考试不是什么难事。

当然,托福考试有一点特殊。托福本身的学习内容不多,不像其他考试要花很多时间去学你以前从没有学过的知识。从这点来看,托福似乎可以快速地准备好。但是,托福的水平难以测试。托福有一半的内容是主观题,作为一个考生,你是不知道这些题目是怎么批改的。因此,你无法知道自己当前的水平,也难以知道自己的复习是否能提升成绩。没有可以用进度表示的内容,没有自测方式,很多时候你根本不知道怎么复习是有效的。准备托福就像做优化函数求不了导数的优化问题,正常算法根本行不通。

对于自学托福的考生,我还有更多的经验想要分享。不谈准备考试时的认真程度,如果我提前知道了这些经验,我的准备肯定会更加顺利。我非常想把这些经验分享出来。

自学托福的注意事项

  1. 不要在第一次考试前估计自己的托福水平。在我的备考经历中,我曾经用两年前的托福成绩以及这两年英语水平的提升来预测我现在的托福水平。事实证明,我高估了自己的提升,低估了考试的要求。如果没有天天讲英语的环境,口语和听力是不会大幅提升的。而如果又没有经常写英文论文的话,写作水平也不会有太大提升。因此,对于国内大多数理工科学生,只有阅读水平会在大学期间得到提升。另外,哪怕你英语各方面都学得很好,你也要花一些时间去学习应试技巧,把英语水平转换成托福成绩。总之,在实际考出一次成绩之前,你是难以估计自己的托福水平的。最好多留一点时间,在准备得差不多的时候就可以不带压力地先考一次试,看看自己的水平。
  2. 自学口语和写作是很难的。阅读和听力都是客观题,你可以靠对答案来不断反思提升。但是,如果完全不借助外界的帮助,你是难以评估自己的口语和写作水平的。无法评估自己的水平,自然也就难以找到提升的方向。因此,如果想完全自学上100分,一定要把主要时间花在阅读和听力上,这两门要尽可能做到满分。还有时间可以去练写作。至于口语,我看了很多托福考试经验,多数自学托福的人都说口语不好提升。我也建议口语练到22分左右就够了。大多数学校不会要求更高的口语分数。相对地,如果非得要报班学习,最好是请别人帮你提升口语水平,再是帮你批改作文。阅读和听力则完全没有报班学习的必要,全部可以自学。
  3. 注意考试状态。托福考试不仅是在考英语,还在考短期记忆力。在做阅读时,你要大致记住每段的主要内容,这样才能快速地答完最后一道六选三;做听力时记忆就更重要了,答题时八成靠记忆,两成考笔记。口语和写作倒还好,只要笔记记得好,不需要靠脑子来记忆。因此,考试时,一定要保持清醒,保持注意力集中,这样才能记得住东西。托福和大学里的考试不太一样,不是你会就是会,不会就是不会,你还得在有限的时间里把题目答完(我觉得托福甚至比GRE还吃状态)。这一点和高考很像。考前一定要好好调整作息,拿备战高考的状态来备战托福。

总结一下自学托福的整体策略。不管你的英语水平有多高,最好是至少留三个月时间来从容地考完托福。第一个月好好准备,主要是准备能够自测的阅读和听力。写作和口语熟悉题型,保证能回答出来即可。由于不能自测,提升写作和口语的效率是极低的。哪怕不花太多时间准备,也没有关系。准备得差不多了,就可以先考一次试,测试一下水平,熟悉一下考试流程。考完了,如果成绩还差一点,第二个月再开始根据考试成绩做进一步的提升。要考100分以上,最好是阅读听力接近满分,口语22左右,写作能考多高考多高。根据考试结果,哪一部分离目标远就去着重提升哪一部分。至少预留一个月来进行考试,不要让不充足的时间给你带来压力。如果准备时间不够,想要报班,那就只报班提升口语和写作。

接下来,我再详细谈一谈托福的四个部分分别该怎么准备。

托福听说读写的提升方法

基础

网上有很多人会讲自己的备考经验。但是,如果英语基础不同,准备的策略也不同。因此,我想定义一个开始准备托福的最低英语水平要求。这个水平是大多数人开始准备托福的水平,适用性较广。如果你还没有达到这个水平,可以先打基础,再开始针对托福进行准备。

要学习托福,听、读、写能力要达到高考接近满分的水平。此外,要记住大部分四六级单词。对于口语,要能够在提前准备的情况下对着PPT做英语课堂展示。

听力和写作的要求应该没什么争议。毕竟大多数大学生的听力和写作都是高考水平,再之前也记不起中学老师是怎么教听力和写作的了。

对于阅读水平,我认为除了有高考水平外,还需要多记一些单词。不必去背托福单词,背完大部分四六级单词就行了。四六级单词其实是英语里的常用词。如果你熟悉了四六级单词,那么你阅读日常的英语文章是不会有单词上的问题的。剩下的一两个不会的单词,你查一查,也就记住了。我不建议去死记硬背单词,尽可能通过阅读的方式自然记忆单词。

由于各地教育水平的不同,口语很难找到一个统一的基础。在我看来,开始练托福口语之前,只要会用英语表达意思就够了。哪怕用的单词很简单,或者某些单词不会,都没关系。这种口语的基础可以通过英语演讲水平反映出来。大学里的英语课是会要求做简单的演讲或对话的,如果你能在认真准备后完成课堂展示(不必是即兴的),英语的口语基础应该就没有问题。

阅读

要做好托福阅读,要从语法、单词、阅读方法、应试技巧这几方面学习。

英语的语法不难,高中学的语法够用了。哪怕什么定语从句、宾语从句的具体规则都不记得了,也没关系,看几篇文章就回忆起来了。问题是,托福的阅读文章中有不少长句。常常有句子刚说到一半,插入一个逗号,接个从句,又接回之前那句话。这导致我们读着读着就忘掉句子前半部分在讲什么了。刚接触托福阅读时,是会碰到这种句子语法结构分析不清楚的问题。这时,可以花点时间去分析一下这些长句。只要弄懂句子主干是什么、哪个从句是修饰哪个词即可,不用过分深究语法细节。多读几个这种句子后,语法上就不会有问题了。

如前文所述,开始做托福阅读之前,有四六级的单词基础即可。剩下的单词可以通过阅读自然地学习。除去四六级的常用词外,托福的阅读材料常常会包含一些专有名词、历史名词。碰到这种看不懂的名词是非常正常的一件事,千万不要慌张,知道它们是指代某个物体即可,不懂这些单词不影响上下文的理解。如果碰到了没背过的形容词,就去查一查,稍微记一下。另外,我发现考试里的阅读单词比TPO的更简单一点,基本上都是英语里的常用词。如果做TPO碰到了较多的生词,也不必太担心。

打好了语法和单词的基础,接下来最重要的就是阅读文字的方法了。如果不限时间的话,我们当然可以一个单词一个单词去翻译,一句话一句话去理解。可是,在考试中,时间是有限的,我们必须要学会如何快速阅读英语文章。我个人认为,想读好英语文章,要能够带着自信去读,就像读中文一样。看网上的中文资讯的时候,我们肯定不会逐字逐句去读,而是很快地扫两眼,哪怕内容没看全,也大概知道文章在讲什么。读英语文章也是一样。千万不能去先把单词翻译成中文再理解,或者一边念出来一边理解。要敢于看到几个单词,就去脑中拼凑这个句子的意思,想着怎样尽快把文章读完。保持这种习惯,我们看英语就能像看中文一样快速。我认为阅读方法是托福阅读中最重要的一环。我的很多阅读方法都是以前看讲GRE长难句的教材学到的。如果准备时间充足的话,可以去看看一些有关GRE阅读的书。学会那些方法后,看托福阅读就是降维打击。另外,如果你一两年后才需要考托福,可以平时就多看看本专业的论文。带着目的去看生活中的英文文章是提升阅读水平的最佳方法。

足够的英语阅读水平,并不能保证你在托福阅读中拿到高分。应试技巧,是把能力转换成分数的关键。托福的阅读题是有一些规律的。比如问「下面哪个句子能够解释文章中的句子」,那些错误的句子要么是漏了信息,要么是逻辑关系搞错了。再比如,六选三中,一些错误的选项会有「一定」、「所有」这种过于武断的描述。这些应试技巧培训班的老师应该会讲。但我觉得,这些技巧最好是自己通过大量的练习领悟,这样就不需要花精力去记忆了。另外,托福阅读题还有一个很重要的技巧:多数题都会问哪个选项描述正确。对于这些题目,你是能够逐个判断出每个选项是正确还是错误的。这样,一道题其实可以从两个角度来做:三个选项的内容文章都没提,所以它们错了;一个选项的内容文章提到了,所以它对了。如果从两个角度你都得到了同一个答案,那这题你肯定做对了。我考试中做阅读题卡住的几次,都是排除了两个选项,觉得一个选项很对,但最后那个选项没找出它错在哪里。最后一次考试出分前我之所以觉得阅读做得很好,也是因为每道题我都确认了两遍。

总结一下,托福阅读只要无脑做TPO即可。在这个过程中,你会自然地学会分析长句,学会如何跳过不会的名词。为了加快阅读速度,你可以有意识地练习一下英语速读。多做一些题目,多找一下题目中的规律,你差不多能够凭借自己学会应试技巧。练习时,别忘了练时间分配。一开始练习时可以先在完全没有时间限制的情况下尽可能提高正确率,之后再在限定的时间内尽可能完成答题。完成所有这些的练习,就足以在托福阅读中拿到高分了。对了,考试前一定要睡醒。六选三其实就是考察你记没记住文章里的内容,要去文章里确认每个描述是否被提及。如果状态不好的话,你可能就记不住文章的内容,做六选三会非常吃力。

最后再分享一下我阅读文章的节奏。托福阅读题是顺着文章内容出的。比如前两题是和第一段相关,之后的两题是和第二段相关,只有最后一道六选三是和全文相关。因此,我阅读文章时,会先阅读一整段,再看这一段的题,题目问到哪我就再把哪看一遍。做完这一段相关的题,再去读下一段。我以前也试过先读全文再去做题,但我发现阅读全文时我对文章的理解很不到位。而一边做题一边读的话,能深刻理解每一段的意思。做完除最后一题外的所有题,也差不多懂了全文的意思。

听力

我之前已经写了一篇托福听力的准备方法了。我觉得这篇文章写得很好,方法非常系统,总结非常到位。可惜,可能文章写得太花里胡哨了,似乎看的人不多。我这里再稍微总结一下那篇文章的意思。

托福听力分三部分去准备:听力能力、记忆能力、理解能力。如果听不懂文章的内容,就去用精听等方法提升翻译英语语音的能力;如果听完了总记不住,就去练习应试技巧,尝试提前预测出题点;明明听懂了题目却做不对的情况较少,如果有,多反思一下做题时的想法,看看是不是做题做得太快了,选项没读懂。练习英语听力时,先想办法把错因归为上述三类中的某一类,再集中练习。

口语

我完全理解不了口语是怎么批改的,也没资格介绍口语经验。非得说我从五次口语成绩中分析出了什么,我只能说,越紧张,口语分越低。我最自信的第一次考试口语的分数反而最高。想口语拿高分,就不要自学,寻找外界的帮助吧。

如果给我充足的时间,我会按这个顺序准备口语:一开始练独立口语,学习一下独立口语怎么想内容(比如一个话题展开两到三个角度,或者观点正着说一遍反着说一遍),保证这题有话可说。之后练后面三道口语题,熟悉题型,知道大概的套路。再之后提升后面三道题的流利度,学习高分回答的答题方式,如一般要说几句话、材料的内容要复述得多准确、如何用连接词过度。最后还有时间,再去练习独立口语,准备一些常用的段子,尽可能减少说不出话的情况。离考试越近,越要多练口语,让自己在考试的时候充满信心。

写作

综合写作比较简单。综合写作会先给一段材料,材料中有一个观点的三个分论点。用三分钟读完材料后,会有一段听力,逐个反驳材料中的三个分论点。听完听力,在可以看到材料的情况下复述听力内容。相比听力测试,这段语音的理解难度很低。演讲人语速很慢,且关键处会反复重复。因此,在做听力笔记时,完全可以把演讲者观点中的关键词全部记下。写作时,只要把关键词串成句子即可。平时练习个一两次,学一学套路,掌握个简单的模板,知道怎么用连接词把句子串起来,考试的时候综合写作做起来就会非常轻松。

多数人讨论托福阅读都是在讨论独立写作。独立写作就完全是根据一个话题自主写作了。独立写作有两大要点:文章内容如何组织、英语用法是否地道。

文章的组织方式可以在短期内学会。准备写作时,尤其是在备考时间不够时,要把重点放在这上面。托福独立写作常常会给一个大家都听过,但不是那么好写的话题。想要漂漂亮亮地在半小时内论证这种话题,哪怕是拿中文,都是不太可能的。不过还好,托福写作考察的重点不是内容是否合理,而是英语表达能力是否过关。因此,考前要多花时间学习展开话题的技巧,知道怎么用恰到好处的废话来完成一篇英语写作。

组织文章内容,又需要考虑三个方面:文章整体结构、每段的分论点、段内句子展开方式。

文章整体结构,指的是你是无条件支持或反对题目的观点,还是一半肯定还是一半否定,还是其他什么形式。有非常多的托福写作教程会讲文章整体结构的分类。我反正只会一种叫做「一边倒」的方法,无条件支持或反对题目的观点,然后每一段都强烈地支持我的观点。这种方式最简单粗暴。其他常用的还有两段肯定,最后一个主体段委婉地让步一下。建议大家选择某一种组织方式,平时练习和考试都只用这种方式。

想每段的分论点,即怎么把话题展开,分成多个方面讨论,让自己有话可说。就比如我正在写的这篇托福经验分享,在介绍托福各个部分的准备方法时,我会讲这个部分要从哪几个方面学习,再详细谈每个方面具体怎么做。问题是,平时写文章,我都是有了要表达的东西,再用清楚的逻辑去写作。而托福写作是逼着你去对着一个话题,在三十分钟内组织一篇空洞无比的文章。我都没有想写的内容,只能生搬硬凑了。所以在我看来,想托福写作的分论点,就是要大开脑洞,能怎么扯怎么扯,只要是和主题相关的分论点即可。各个分论点之间是否有着密切的逻辑关系?没人会在意这个。不同题材,有不同组织分论点的方式。我多数情况是对着题目临场想分论点。我唯一掌握的一种模板,就是分个人、组织(公司或大学)、政府分别讨论。建议平时练习独立写作时,可以多看几个题目,不用写作,就想一想分论点。之后再看看范文的分论点,学一学别人是怎么组织的。

想好了文章结构和分论点,还要知道每一段的内容具体该怎么写。其中,开头结尾的组织方式和主体段的组织方式不同。开头结尾是有套路的,可以提前准备。比如开头先复述一下材料,讲一讲多数人的观点,最后坚决地提出自己的观点;结尾的时候总结一下每个主体段,再重申自己的观点。这两段不用写得很花哨,不是很难。最难的是各个主体段的写法。主体段一般是三个,每段你要用将近100词来详细阐述分论点。没有足够的备考经验的话,很容易碰到主体段没东西可写的情况。我认为,靠自己反复写作是提升不了写主体段的能力的。最好多参考别人的范文和一些教辅书,学习一下段内的叙述方式。一般来说,一个主体段,可以先提出观点,再用一两句话去详细解释,最后用示例丰富观点。句子之间多用连接词,多构建因果、递进、比较等关系。示例不用像高考语文一样非得用名人示例,也可以用生活中的普遍情况或者自己的例子。介绍这种写作技巧的书和资料非常多,建议大家去找一下,一定要学习别人的经验,不要闷头反复写文章。

组织段内内容时,有一些赖皮的方法。比如,可以一句话,同样的意思,颠来倒去地说。再比如,可以准备一些万能的句子,放到文章的开头结尾,或者某一主体段中。甚至文章话题不难的话,你可以把提前背好的一整段都搬过来。

举个例子。 “xxx plays a crucial role in our daily life. That is to say, xxx is important to us”。我见过诸如此类看上去华丽,但全是废话的句子。没办法,这种优美的句子确实好使,只要你不是全文都用这种废话就没事。

我对这些方法不做评价。反正托福是一个功利性很强的考试,能考高分怎么做都是对的。如果你时间有多,可以去背一背做一点准备。

除了文章的内容,剩下要考虑的就是如何让表达更加地道。比如,“孩子是祖国的花朵”显然就不是英语里会出现的说法。妥当的语言表达是一件很难在短期内学会的事。哪怕是中文,由于互联网流行词几个月一换,几年前的文章现在看来也不是那么时尚。没有语言环境是难以真正地提高语言表达的。因此,我推荐自学的考生完全放弃提升语言表达。

当然,我最近才发现一种效率较高的自学语言表达的方式:利用ChatGPT。你把自己写的文章输入进ChatGPT,让它帮你润色,它能够在完全不改原意的情况下让文字观感大幅提升。这等于是有一个教师手把手带着你练写作,告诉你哪里怎么改,怎么提升。ChatGPT的出现完全解决了英语写作无法自测的问题。我认为,如果使用得当的话,ChatGPT完全可以代替写作老师,让学生自己就找到提升写作的方向。很可惜,ChatGPT出现时我还在闷头准备托福,没来得及深入探究ChatGPT的用法,也没有使用ChatGPT进行写作练习的测试结果。我大胆预测,多数人用ChatGPT提升自己的英语表达后,托福写作至少能高2分。

以上内容主要是我从别处学到的托福写作准备方法,可以保证这些准备的方向是有效的。接下来我再稍微谈一下我自己的托福写作考试经验,并对托福写作的准备方法做一个总结。我也很纳闷,不知道为什么第一次考试写作拿了26分,后面就再也拿不到了。现在想来,估计是因为那次的题目比较简单,我的段内表达都比较流畅。后面每次独立写作时我都感觉很不顺利。在这几次考试中,我都举了和自己相关的例子。这些例子太假了,完全没有说服力。因此,我建议按以下顺序准备托福写作:尽快选择一种文章组织方式,比如最简单的「一边倒」式,并稍微学一下开头结尾的套路。之后花主要时间学习如何想分论点和组织段内句子,要做到有话可说,能够流畅自然地写完三个主体段。最后时间有多的话,可以借助ChatGPT提高语言表达能力。顺带一提,我每次独立写作就是300多字。只要够300字,字数多少并不影响成绩。当然,你有话可说,可以写出很多内容,那再好不过了。

总结

想自学托福上100分,建议写作和听力拉满,尽力学写作,口语差不多即可。

写作对着TPO无脑练就行。除了对答案提升解题能力外,着重提升快速阅读英语文本的能力。

听力也是对着TPO练。不过,要先分析自己的弱项,不要上来就去精听。在反复的自测中,逐个提升听懂、记住、写对的能力。

综合写作练一两次熟悉题目即可,稍微准备一个简单的全文结构模板。独立写作先学文章整体结构和开头结尾的技巧,再着重学习提论点和填充段内内容的方法。可以借助ChatGPT批改作文。

听力写把独立口语练到能说的程度,再熟悉三道综合口语的题型。之后主要提升综合口语,有时间多再认真准备独立口语。说口语时一定不能紧张。

搜索引擎根本搜不出有用的托福攻略。建议上留学论坛找前人的经验分享,这些帖子里经常能够找到许多学习资源(教材、作文范文)。四个部分中,写作的技巧是最需要向别人学习的。

在这篇文章中,我动之以情,晓之以理,把我在五次考托福中学到的东西全讲出来了。从今天起,我的脑中再也不想出现「托福」这两个字了。希望大家阅读后有所收获,早日战胜托福考试。

最近,我在考托福。第一次考完,我惊讶地发现我的听力只有19分(满分30)。我两年前考的成绩都不止这么一点。这两年我也一直在接触英语,英语水平不可能退步。如果托福真的是一个合理的,能够反映考生真实水平的考试,那我不可能考出差别这么大的分数。因此,我认为,考出这么差的分数,不是我的问题,是托福考试的问题。托福考试过分强调应试技巧,而无法稳定地反映我的水平。为了证明我的观点,我不服气地宣誓道:「我要用一周时间,学习听力应试技巧,考出一个较高的听力分数。」一周后,我又考了一次托福。结果很喜人,我的听力考了27分。

在这几天里,我总结了网上的托福听力准备方法,用一套规范的算法流程把方法表达了出来。我将分享一下这一套和深度学习算法形式相似的托福听力准备方法。我会从头把方法的背景、解决方法讲清楚。哪怕你不需要准备托福考试,或者说对深度学习没那么了解,也可以把这篇文章当故事读一遍。

托福听力规则

托福听力主要涉及两类材料:对话与讲座。对话通常发生在学生与教授或学校工作人员之间,描述了校园中常见的一些讨论、询问。讲座则模拟了真实的课堂教学,讲师会对某一专业话题做简要的描述,偶尔会穿插几句学生的提问。讲座涉及的知识面很广,常常会谈及艺术、历史、生物学、地理学、心理学等领域。不过,这些讲座不会讲特别深入的内容,也不会讲过于偏离生活的概念,保证多数人都能听懂讲座的内容。

一场无加试的托福听力由两轮组成。每轮会听1段对话和1~2段讲座。对话有5题,讲座有6题。

托福听力的答题形式和国内多数英语考试不同。每段听力中,考生只有听完了听力材料后,才能看到题目。并且,只有确认提交了前一题才能答下一题。当然,可以在听的同时记笔记。

题目全是选择题。每段听力材料给3~4分钟,基本不会有时间不够的情况出现。

托福这种不允许提前看题的考试模式把考生的记忆力也变成了考察目标之一,为考试增添了不合理的难度。后续算法的诸多改进都是为了解决「记不住」这一问题。

托福应试流程

大多数人在初次接触托福听力时都会采用这一套非常直观的算法:

但是,这套算法有一个问题:无论我们的听力水平多么优秀,都不可能把材料原原本本地记忆下来。一旦有题目考察了一个我们没记住的地方,这道题就答不上来了。

因此,多数托福教程会给出一套改进的算法。这套算法把「听材料」细分成了两部分。第一部分是「语音转文字」,这是我们大脑自己训练出来的功能。理解了听力材料的内容后,我们根据一些先验知识(知道听力材料的哪些部分容易出考题),对部分内容做重点记忆。重点记忆的方法可以是竖起耳朵集中注意力,也可以是用笔记记录关键词。最后,根据重点记忆的内容和大脑中残留的其他记忆,回答问题。

这套算法出色地把托福听力分成了三个独立的子任务:语音识别、记忆、阅读理解。这三个子任务恰好对应了三种要考察的能力:听力能力、记忆能力、理解能力。把这三种能力拆开来讨论是很有必要的。如果你没有意识到自己哪种能力相对相差,盲目地做题,尝试同时提升这三项能力,那你的学习是十分低效的。后文我们将基于这一套算法讨论如何分别提升这三种能力。

错因分析

在正式开始学习之前,我们一定要诊断出自己哪方面的能力有所缺失,进而对症下药,从最差的一项开始练习。为了找出做听力的问题,我们要进行错因分析。

为了考察听力、记忆、理解这三种能力,我们固然可以用做托福听力题以外的方式去分别测试这三种能力。但是,使用其他测试方式的话,我们不能保证我们测试出来的能力恰好是托福考试要求的能力。比如,你可以拿托福阅读来测试自己的理解能力。但是,由于托福阅读的理解难度比听力的理解难度要大,哪怕你阅读做得不好,也不能说明你听力就理解不好。因此,做托福听力题这件事本身就是最好的测试方式。

可是,正如前文所述,做托福听力会同时考察三种能力。该怎么分别考察这三种能力呢?其实,使用一些巧妙的控制变量法,就可以把这些能力区分出来了。我把网上提出的各种错因分析方法加以总结融合,提出了一种「反掩码错因分析」法。

什么是「反掩码」呢?众所周知,掩码 (mask),指的是计算机中用于屏蔽其他数据的一种数据,有时也泛指屏蔽数据这一过程。那么,我提出的「反掩码」,指的就是把原本不透明的数据变得透明。

在托福听力中,听力材料是不透明的。你需要通过自己的听力能力把听力材料变成可理解的文字。如果直接把听力材料变成了阅读材料,你就可以直接根据听力原文答题。这样,做题考察的就只有理解能力,而不再考察听力和记忆能力了。

因此,为了考察自己的理解能力。可以先听一遍材料,做题,不看答案,读一遍原文,再做题。第一遍做题是正常的练习,第二次做题是控制变量。如果第二次做题还是错了很多,就说明理解能力不行;如果第一遍做题相较第二遍错了很多,就说明听力和记忆存在问题。

同理,为了进一步区分听力和记忆的问题,我们可以调整反掩码,使用更精妙的控制变量方式:先听一遍材料,做题,不看答案,再听一遍材料,再做题。在第二次听力的时候,你已经知道题目了,不存在记不住的问题。如果第二遍答题还是错了很多,就说明听力和理解有问题;如果第一遍相较第二遍错了很多,就说明记忆有问题。当然,如果你的理解已经基本没问题了,在这一步就可以排除掉理解能力的影响,直接区分听力问题和记忆问题。

综上,「反掩码错因分析」的流程如下:

  1. 找出几套听力题。先听一遍材料,做题,不看答案,读一遍原文,再做题。主要分析自己的理解是否存在问题。
  2. 理解能力最容易提升。先想办法提升理解能力。
  3. 再找出几套听力题。先听一遍材料,做题,不看答案,再听一遍材料,再做题。区分听力问题和记忆问题。
  4. 根据诊断结果,先后提升听力问题和记忆问题。

能力提升

诊断出了问题后。就应该考虑如何设计子任务分别训练这三种能力了。

听力能力

使用听力能力时,输入是英语音频,输出是大脑中可理解的英文文字。这一过程完全由大脑的本能决定,几乎不需要主观思考。因此,我们可以设计一个非常直接的子任务:使用任意一种英语音频,一边听,一边「输出」自己的听到的文字。听完音频后,比较自己的输出和原文,看看自己哪些地方没听懂。不断反思,让大脑自己提升。

网上的很多托福听力攻略会叫你去「精听」,「听写」,或者使用什么「影子跟读法」。其实,所有这些方法都只是为了提升听力水平。他们的目的都是构造一个输入音频输出文字的训练任务。只是「输出」的方式不同而已。在我看来,输出文字的方式,可以是复述,可以是听写,甚至不用讲出来,在大脑里有个印象就行。所有这些具体的方法里,我不推荐听写法,因为你大量的学习时间都会浪费在写单词上。听完一句话后,复述这句话即可。至于是一句话反复听,还是听完一整段材料,这些形式都不重要。保证你在强迫自己不断输出听到的内容即可。

另外,与其纠结听力练习的具体方法,不如去花一点时间准备恰当的听力材料。对于有高考英语听力水平的人,直接拿托福官方题目(TPO)的听力材料练就可以,类似于TED的知识分享、新闻播报也可以。看电视剧的提升可能没那么快,因为电视剧的语速较快,且通常只有日常用语,与托福听力材料的内容不符。我个人推荐去上自己专业的英文公开课,比如大名鼎鼎的「MIT线性代数」。这些公开课语速适中,用语朴素,比较容易听懂,形式与托福听力类似,还可以顺便学一下专业知识。

记忆能力

托福听力考试的记忆分两类:第一类是被动记忆,也就是你听完整段材料后自然残留在大脑里的记忆;另一类是主动记忆,是你根据以往的答题经验,对材料中你认为的重点段落的记忆(或者笔记)。被动记忆我们没法操控,只能祈求考试当天头脑清醒一点。因此,练习记忆时,主要是练习托福听力应试技巧,提高对出题点的灵敏度,并且做到在不影响听力的情况下记下笔记。

在培训班或者网上,都能找到大量的托福听力技巧。这些技巧告诉你什么地方容易出题,听到哪些词的时候做笔记之类的。但是,我觉得背技巧的效率是很低的。最好的方式是,自己去尝试发现技巧,有了一定的经验后再去和别人的技巧对比。

那么,怎么去学会洞察听力材料的出题点呢?有两个方面的事要做。一方面,做完了题目后,总结归纳常见题型,并且找题目对应的原文,总结规律;另一方面,在听材料时,多做一点笔记,做题时看看哪些笔记用到了,哪些没有。甚至,对于以前练过的材料,可以直接跳过听力,读原文,猜出题点。通过正向和反向的练习,很快就能掌握技巧。

作为示例,我来分享几个我发现的技巧。

  • 对话材料90%会问对话发起的原因是什么。注意,这道题不是问整段对话的主题是什么,而是问学生为什么找老师或老师为什么找学生。很可能两个人寒暄了老半天,学生突然支支吾吾地说出自己过来的原因,然后话题一下就飞走了。如果你没有预判出题点,很可能这一句话就被你忽略了。
  • 讲座有时会先讲理论,再讲示例。题目会问这个示例是揭示哪个理论。如果示例里恰好冒出两个专有名词,你又不知道这里会出题,这个例子就会被你忘掉了。因此,听到示例,赶快把示例里涉及的几个名词记下来。
  • 讲座中,老师和学生都可能会表达对于某一理论的态度。尤其是结尾,老师很可能冷不丁地说「虽然这个理论很有名,但我并不是很认可」。题目很喜欢考这些态度。你可能辛辛苦苦听了4分钟的材料,想着听力要结束了,可以放松一下了。结果最后这两句很重要的态度就被你忽略了。

找到材料中的出题点,其实只是比较初级的技巧。如果你想成为高手,还有一种高阶的,更一般的方法。我把它称之为「基于注意力的记忆法」。听力材料,甚至是阅读材料,以及我们生活中各种各样的文字讯息,都是有很多废话的。如果你知道哪些话言之无物,你就可以过滤掉这些话,而集中记忆那些有意义的话。更进一步,如果你知道听力中哪些话一定不会出题,那么这些话也可以过滤掉。你可以只把注意力集中在哪些容易出题的有意义的话上。

比如说,老师讲「好久不见」、「我们上节课讲了什么什么,但是这节课……」。这些话都可以直接当成废话忽略掉。讲座时提到某个名人出生于何处,几几年在哪上学,这种过于细节的地方也不可能出题,可以忽略掉。

这种基于注意力的记忆法是很有用的。一来,它可以帮你找到重点和非重点的话,有助于合理安排精力;二来,在知道这句话不重要时,可以放心地去记上一句话的笔记。

这种记忆法不用刻意训练,多听了几段材料后就能自然地知道听力材料中哪些地方是可以忽略的。

理解能力

托福听力中的理解,既包括对材料的理解,也包含对题目和选项的理解。整体来说,托福听力对理解的要求不高。只要清楚地听懂了材料,做题时一般不会因为理解而扣分。

材料理解比较简单。只有介绍一个复杂的理科概念,或者人物表明态度时过于委婉,才可能导致材料理解错误。只要把托福阅读题练好,听力材料不可能有理解问题。

理解题目、选项时则可能会碰到一些问题。比如在主旨题中,每一个选项概括主旨都概括得不是很确切,但是其中三个选项有明显错误,一个选项概括得不全。这个时候,得选择那个概括得不全的选项。为了解决这个问题,只需要多做点题,总结记录选项理解错误导致的错题,很快避开常见的一些坑。另外,听力的时间非常充足,碰到模棱两可的选项时可以多读两遍题,千万不要把题目读错。

综合练习

TPO 30之前的题都比较简单,可以拿这些题目来反复训练各个子能力。预训练好了大脑中的各个子模块后,就可以直接开始做TPO 30之后的题,进一步综合训练所有能力。

综合练习时,应按照正常的应试流程,一边听材料,一边做笔记。听完了材料就做题,对答案。有题目做错了,大致把错因归个类,看看是三步中哪一步出了问题。根据错题的情况做进一步的查缺补漏。

在评估自己的综合练习水平时,除了对答案,还需要反思两件事:材料中的哪几句话没有听清;笔记是否记下了重要信息。

很多时候,你可能听漏了某几句话,但依然把题目都做对了。但这样并不保险,最好是把每一句话都理解透来,保证听力能力没有问题。

哪怕你已经熟悉了材料中的出题点,依然需要练习一下记笔记的方法。首先,你要保证记笔记的时机合理,不能影响听力。其次,你要保证重要的信息没有记漏。在不影响听力的前提下,多记笔记是没有坏处的。因此,反思时,只需要评估哪些题目中考察的信息是漏记的,尤其是哪些导致你做题出错的漏记的信息。

最后,再总结一下各种能力在托福听力中的重要性。理解能力是必须的,也是难度最低的。一定要把理解能力练到满分。听力能力是答题的基础,是托福听力题的主要考察对象。听力能力是最考验基本功,最难在短期提升的,一定要早做准备。记忆能力是把听力能力转换成分数的必备途径,你可能文章听得津津有味,答题时却头脑发昏。记忆能力最好提升,几天内就可以总结套路,掌握记忆关键信息的方法。

方法总结

我们来把这篇文章讲过的托福听力准备方法和应试方法整理一下。

准备考试:

  1. 做几套TPO听力,使用分析错因,找出自己较弱的能力。分析错因时,先听一遍,做题,然后在不看答案的前提下再听一遍或看一遍原文,排除出做题出错的原因。
  2. 从较弱的能力开始,逐个提升提升能力。
  3. 练得差不多了以后,综合练习,查缺补漏,同时练习应试的状态与技巧。

考试时:

  1. 听语音,在脑中把语音变成文字,判断这句话有没有重要信息。
  2. 如果觉得这句话很重要,就着重记在脑子里,或者用笔记记下。对话可以少记或不记笔记。讲座必须记下关键信息,比如举例、分类讨论、态度之类的。
  3. 根据笔记或印象答题。

托福听力题本质上还是在考察听力能力。如果你听力的基础不好,还是要把主要的时间花在打基础上。等听力水平差不多了,可以按照这篇文章介绍的内容,学习一些技巧,把听力能力转换成分数。希望大家能够考出一个不错的托福分数。

附录:托福听力框架图

我在学习今年的一篇和人脸生成相关的论文时,看到了一个约束人脸相似度的 loss。这个 loss 基于经典的 ArcFace 人脸识别网络。ArcFace 是 2019年发表的,这么久了还有人用,说明这是一篇适用性很强的工作。于是,我顺手学了一下 ArcFace 的相关背景。在这篇文章中,我将简要分享 ArcFace 人脸识别网络的发展历程,并介绍如何快速利用它的开源 PyTorch 项目计算任意两幅人脸的相似度。

人脸识别与 ArcFace

人脸识别是在深度学习时代取得较大突破的一项任务:给定一个登记了 N 个人的人脸数据库,再输入一幅人脸,输出这个人是否是 N 个人中的某一个。

多数深度学习算法会用一个 CNN 来提取所有人脸图片的特征,如果输入的图片特征和数据库里的某个特征的向量相似度大于某个阈值,就说明识别成功。也就是说,人脸识别的关键在于如何用 CNN 生成一个「好」的特征。特征的「好」体现在两点上:1) 同一个人的人脸特征要尽可能相似;2) 不同人的人脸之间的特征要尽可能不同。

为了达成这个目的,研究者提出了不同的学习特征的方法。最直观的方式是像学习词嵌入一样,用一个具体的任务来学习特征提取。恰好,人脸识别可以天然地被当成一个多分类任务:对于一个有 N 个人的人脸训练集,人脸识别就是一个 N 分类任务。只要在特征提取后面加一个线性层和一个 softmax 就可以做多分类了。训练好多分类器后,扔掉线性层和 softmax 层,就得到了一个特征提取器。

这种基于 softmax 分类器的学习方法确实能够区分训练集中的人脸,但在辨别开放人脸数据集时表现不佳。这是因为 softmax 的学习目标仅仅是区分不同类别的人脸,而没有要求这种区分有多么分明。后续的多篇工作,包括 ArcFace,都是在改进训练目标 softmax,使得每类对象之间有一个较大的间隔。

为了让不同的类别之间存在间隔,研究者们详细分析了 softmax 函数。假设$x \in \mathbb{R}^d$是人脸图片提取出的维度为$d$的特征向量,它属于$N$类里的第$y$类。softmax 前的线性层的参数是$W\in \mathbb{R}^{d\times N}, b\in \mathbb{R}^N$。则基于 softmax 的多分类误差可以写成:

其中,向量内积$W^T_jx$可以展开为:$W^T_jx=||W^T_j|| ||x|| cos \theta_j$,其中$\theta_j$是两个向量的夹角。如果对向量$W^T_j$和$x$都做归一化的话,再令$b=0$,则新的误差可以写成:

也就是说,对于这种归一化的多分类误差,对误差产生贡献的只有特征向量和$W$的列向量的夹角。那么,我们就可以换一个视角来看待这个误差:$W$其实是$N$个维度为$d$的向量的数组,它们表示了$N$个人脸类别的中心特征向量。误差要求每个特征和它对应的中心向量的夹角更小。

为了让不同的类别之间有更大的间隔,相同类别内部更加聚拢,ArcFace 在误差的$cos$中加了一个常数项。

直观上来看,加这一项就是把角度远离类别中心的惩罚扩大,使得各个类别的数据都更加靠近中心。这一优化是有效的。作者做了一个简单的实验,用8类人脸的训练集训练了一个维度为2的特征。其可视化结果如下(线是各个类别中心的方向向量,点是样本的特征向量):

ArcFace 的核心思想就是其 loss 的设计。ArcFace 的网络架构没有特别的要求,一般使用 ResNet 就行。

使用 ArcFace 开源项目

其实用这个库的时候可以完全不懂 ArcFace 的原理。

可能是由于「ArcFace」这个名字和其他项目撞车了,ArcFace 的官方 GitHub 仓库叫做 insightface。其 PyTorch 实现的网址是 https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch。

早期该论文没有官方的 PyTorch 实现。有人用 PyTorch 复现了该论文(https://github.com/TreB1eN/InsightFace_Pytorch )。我使用的项目是这个复现版,它更完整一点,包含了人脸预处理代码。

人脸识别任务最简单的例子是输出两幅人脸的相似度。接下来我将介绍如何安装这个项目,并用它来编写一个简单的计算人脸相似度的 demo。

安装这个库很方便,只要一键 git clone 就行。这个项目基本不依赖什么第三方库,环境里有 PyTorch,NumPy 等常见库就行了。

1
git clone git@github.com:TreB1eN/InsightFace_Pytorch.git

获取仓库后,在 README 的链接里下载 IR-SE50 模型 model_ir_se50.pth。随便放在哪个目录里。比如下载 IR-SE50 @ Onedrive

最后,还要准备一下测试图片。我很机智地使用了吴恩达《深度学习专项》里讲人脸识别时用到的图片,分别命名为face1.jpg, face2.jpg, face3.jpg

准备就绪后,就可以编写 demo 了。在做一个人脸相关的项目时,一般会先对人脸做预处理,让所有的人脸图片都有相同的分辨率,且五官的位置对齐。我们的 demo 也是先利用了项目中的人脸预处理库对齐人脸,再调用模型计算相似度。这三张图片的对齐结果如下:

这个demo会输出第一张人脸和第二、第三张人脸之间的相似度。完整的 demo 代码如下(注意修改其中的路径):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from model import Backbone
import torch
from PIL import Image
from mtcnn import MTCNN

from torchvision.transforms import Compose, ToTensor, Normalize

mtcnn = MTCNN()


def get_img(img_path, device):
img = Image.open(img_path)
face = mtcnn.align(img)
transfroms = Compose(
[ToTensor(), Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
return transfroms(face).to(device).unsqueeze(0)


device = 'cuda'
img1 = get_img('face1.jpg', device)
img2 = get_img('face2.jpg', device)
img3 = get_img('face3.jpg', device)

print(img1.shape)

model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se')
model.load_state_dict(torch.load('model_ir_se50.pth'))
model.eval()
model.to(device)

emb1 = model(img1)[0]
emb2 = model(img2)[0]
emb3 = model(img3)[0]
print(emb1.shape)

sim_12 = emb1.dot(emb2).item()
sim_13 = emb1.dot(emb3).item()

print(sim_12)
print(sim_13)

可能是 NumPy 更新的原因,代码在读取.npy时少了allow_pickle=True,会报错。要在mtcnn_pytorch/src/get_nets.py的几个np.load的地方加上allow_pickle=True。不知道在哪加也没关系,报错了再补上即可。

1
2
3
4
5
6
7
8
9
10
# mtcnn_pytorch/src/get_nets.py
# 约55行
weights = np.load('mtcnn_pytorch/src/weights/pnet.npy',
allow_pickle=True)[()]
# 约100行
weights = np.load('mtcnn_pytorch/src/weights/rnet.npy',
allow_pickle=True)[()]
# 约151行
weights = np.load('mtcnn_pytorch/src/weights/rnet.npy',
allow_pickle=True)[()]

修改完毕后,直接运行脚本就行了。其输出大致为:

text
1
2
3
4
torch.Size([1, 3, 112, 112])
torch.Size([512])
0.5305924415588379
0.03854911029338837

第一个输出是模型输入张量的形状。第二个输出是模型输出的特征的性质。了解这两个形状信息有助于我们调用此库。

第三个输出是第一、第二张人脸的cos相似度,第四个输出是第二、第三张人脸的cos相似度。我们已经事先知道了,第一和第二张人脸图片是同一个人,第一和第三张人脸图片是两个人。所以说,这个输出结果非常正确。

在我们自己的人脸项目中,一般都会准备好人脸预处理的代码。因此,在调用这个库时,可以只把 model.Backbone 里的代码复制过去,只使用该项目的模型即可。使用时注意输入输出的形状要求。

参考资料

ArcFace 论文:ArcFace: Additive Angular Margin Loss for Deep Face Recognition

官方仓库:https://github.com/deepinsight/insightface

PyTorch 复现仓库:https://github.com/TreB1eN/InsightFace_Pytorch

《人脸识别的 loss》:https://zhuanlan.zhihu.com/p/34404607 https://zhuanlan.zhihu.com/p/34436551

刚刚学完了PyTorch的并行训练写法,我来分享一份非常简单的PyTorch并行训练代码。希望没有学过的读者能够在接触尽可能少的新知识的前提下学会写并行训练。

完整代码 main.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel

def setup():
dist.init_process_group('nccl')


def cleanup():
dist.destroy_process_group()


class ToyModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.layer = nn.Linear(1, 1)

def forward(self, x):
return self.layer(x)


class MyDataset(Dataset):

def __init__(self):
super().__init__()
self.data = torch.tensor([1, 2, 3, 4], dtype=torch.float32)

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index:index + 1]


ckpt_path = 'tmp.pth'


def main():
setup()
rank = dist.get_rank()
pid = os.getpid()
print(f'current pid: {pid}')
print(f'Current rank {rank}')
device_id = rank % torch.cuda.device_count()

dataset = MyDataset()
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)

model = ToyModel().to(device_id)
ddp_model = DistributedDataParallel(model, device_ids=[device_id])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

if rank == 0:
torch.save(ddp_model.state_dict(), ckpt_path)

dist.barrier()

map_location = {'cuda:0': f'cuda:{device_id}'}
state_dict = torch.load(ckpt_path, map_location=map_location)
print(f'rank {rank}: {state_dict}')
ddp_model.load_state_dict(state_dict)

for epoch in range(2):
sampler.set_epoch(epoch)
for x in dataloader:
print(f'epoch {epoch}, rank {rank} data: {x}')
x = x.to(device_id)
y = ddp_model(x)
optimizer.zero_grad()
loss = loss_fn(x, y)
loss.backward()
optimizer.step()

cleanup()


if __name__ == '__main__':
main()

假设有4张卡,使用第三和第四张卡的并行运行命令(torch v1.10 以上):

1
2
export CUDA_VISIBLE_VERSION=2,3
torchrun --nproc_per_node=2 dldemos/PyTorchDistributed/main.py

较老版本的PyTorch应使用下面这条命令(这种方法在新版本中也能用,但是会报Warning):

1
2
export CUDA_VISIBLE_VERSION=2,3
python -m torch.distributed.launch --nproc_per_node=2 dldemos/PyTorchDistributed/main.py

程序输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
current pid: 3592707
Current rank 1
current pid: 3592706
Current rank 0
rank 0: OrderedDict([('module.layer.weight', tensor([[0.3840]], device='cuda:0')), ('module.layer.bias', tensor([0.6403], device='cuda:0'))])
rank 1: OrderedDict([('module.layer.weight', tensor([[0.3840]], device='cuda:1')), ('module.layer.bias', tensor([0.6403], device='cuda:1'))])
epoch 0, rank 0 data: tensor([[1.],
[4.]])
epoch 0, rank 1 data: tensor([[2.],
[3.]])
epoch 1, rank 0 data: tensor([[2.],
[3.]])
epoch 1, rank 1 data: tensor([[4.],
[1.]])

下面来稍微讲解一下代码。这份代码演示了一种较为常见的PyTorch并行训练方式:一台机器,多GPU。一个进程管理一个GPU。每个进程共享模型参数,但是使用不同的数据,即batch size扩大了GPU个数倍。

为了实现这种并行训练:需要解决以下几个问题:

  • 怎么开启多进程?
  • 模型怎么同步参数与梯度?
  • 数据怎么划分到多个进程中?

带着这三个问题,我们来从头看一遍这份代码。

这份代码要拟合一个恒等映射y=x。使用的数据集非常简单,只有[1, 2, 3, 4]四个数字。

1
2
3
4
5
6
7
8
9
10
11
class MyDataset(Dataset):

def __init__(self):
super().__init__()
self.data = torch.tensor([1, 2, 3, 4], dtype=torch.float32)

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index:index + 1]

模型也只有一个线性函数:

1
2
3
4
5
6
7
8
class ToyModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.layer = nn.Linear(1, 1)

def forward(self, x):
return self.layer(x)

为了并行训练这个模型,我们要开启多进程。PyTorch提供的torchrun命令以及一些API封装了多进程的实现。我们只要在普通单进程程序前后加入以下的代码:

1
2
3
4
5
6
7
8
9
10
11
12
def setup():
dist.init_process_group('nccl')

def main():
setup()

...

cleanup()

def cleanup():
dist.destroy_process_group()

再用torchrun --nproc_per_node=GPU_COUNT main.py去跑这个脚本,就能用GPU_COUNT个进程来运行这个程序,每个进程分配一个GPU。我们可以用dist.get_rank()来查看当前进程的GPU号。同时,我们也可以验证,不同的GPU号对应了不同的进程id。

1
2
3
4
5
6
7
def main():
setup()
rank = dist.get_rank()
pid = os.getpid()
print(f'current pid: {pid}')
print(f'Current rank {rank}')
device_id = rank % torch.cuda.device_count()
text
1
2
3
4
5
Output:
current pid: 3592707
Current rank 1
current pid: 3592706
Current rank 0

接下来,我们来解决数据并行的问题。我们要确保一个epoch的数据被分配到了不同的进程上,以实现batch size的扩大。在PyTorch中,只要在生成Dataloader时把DistributedSampler的实例传入sampler参数就行了。DistributedSampler会自动对数据采样,并放到不同的进程中。我们稍后可以看到数据的采样结果。

1
2
3
dataset = MyDataset()
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)

接下来来看模型并行是怎么实现的。在这种并行训练方式下,每个模型使用同一份参数。在训练时,各个进程并行;在梯度下降时,各个进程会同步一次,保证每个进程的模型都更新相同的梯度。PyTorch又帮我们封装好了这些细节。我们只需要在现有模型上套一层DistributedDataParallel,就可以让模型在后续backward的时候自动同步梯度了。其他的操作都照旧,把新模型ddp_model当成旧模型model调用就行。

1
2
3
4
model = ToyModel().to(device_id)
ddp_model = DistributedDataParallel(model, device_ids=[device_id])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

准备好了一切后,就可以开始训练了:

1
2
3
4
5
6
7
8
9
10
for epoch in range(2):
sampler.set_epoch(epoch)
for x in dataloader:
print(f'epoch {epoch}, rank {rank} data: {x}')
x = x.to(device_id)
y = ddp_model(x)
optimizer.zero_grad()
loss = loss_fn(x, y)
loss.backward()
optimizer.step()

sampler自动完成了打乱数据集的作用。因此,在定义DataLoader时,不用开启shuffle选项。而在每个新epoch中,要用sampler.set_epoch(epoch)更新sampler,重新打乱数据集。通过输出也可以看出,数据集确实被打乱了。

text
1
2
3
4
5
6
7
8
9
Output:
epoch 0, rank 0 data: tensor([[1.],
[4.]])
epoch 0, rank 1 data: tensor([[2.],
[3.]])
epoch 1, rank 0 data: tensor([[2.],
[3.]])
epoch 1, rank 1 data: tensor([[4.],
[1.]])

大家可以去掉这行代码,跑一遍脚本,看看这行代码的作用。如果没有这行代码,每轮的数据分配情况都是一样的。

1
2
3
4
5
6
7
8
epoch 0, rank 1 data: tensor([[2.],
[3.]])
epoch 0, rank 0 data: tensor([[1.],
[4.]])
epoch 1, rank 1 data: tensor([[2.],
[3.]])
epoch 1, rank 0 data: tensor([[1.],
[4.]])

其他的训练代码和单进程代码一模一样,我们不需要做任何修改。

训练完模型后,应该保存模型。由于每个进程的模型都是一样的,我们只需要让一个进程来保存模型即可。注意,在保存模型时,其他进程不要去修改模型参数。这里最好加上一行dist.barrier(),它可以用来同步进程的运行状态。只有0号GPU的进程存完了模型,所有模型再进行下一步操作。

1
2
3
4
if rank == 0:
torch.save(ddp_model.state_dict(), ckpt_path)

dist.barrier()

读取时需要注意一下。模型存储参数时会保存参数所在设备。由于我们只用了0号GPU的进程存模型,所有参数的device都是cuda:0。而读取模型时,每个设备上的模型都要去读一次模型,参数的位置要做一个调整。

1
2
3
4
map_location = {'cuda:0': f'cuda:{device_id}'}
state_dict = torch.load(ckpt_path, map_location=map_location)
print(f'rank {rank}: {state_dict}')
ddp_model.load_state_dict(state_dict)

从输出中可以看出,在不同的进程中,参数字典是不一样的:

1
2
rank 0: OrderedDict([('module.layer.weight', tensor([[0.3840]], device='cuda:0')), ('module.layer.bias', tensor([0.6403], device='cuda:0'))])
rank 1: OrderedDict([('module.layer.weight', tensor([[0.3840]], device='cuda:1')), ('module.layer.bias', tensor([0.6403], device='cuda:1'))])

这里还有一个重要的细节。使用DistributedDataParallelmodel封装成ddp_model后,模型的参数名里多了一个module。这是因为原来的模型model被保存到了ddp_model.module这个成员变量中(model == ddp_model.module)。在混用单GPU和多GPU的训练代码时,要注意这个参数名不兼容的问题。最好的写法是每次存取ddp_model.module,这样单GPU和多GPU的checkpoint可以轻松兼容。

到此,我们完成了一个极简的PyTorch并行训练Demo。从代码中能看出,PyTorch的封装非常到位,我们只需要在单进程代码上稍作修改,就能开启并行训练。最后,我再来总结一下单卡训练转换成并行训练的修改处:

  1. 程序开始时执行dist.init_process_group('nccl'),结束时执行dist.destroy_process_group()
  2. torchrun --nproc_per_node=GPU_COUNT main.py运行脚本。
  3. 进程初始化后用rank = dist.get_rank()获取当前的GPU ID,把模型和数据都放到这个GPU上。
  4. 封装一下模型
    1
    ddp_model = DistributedDataParallel(model, device_ids=[device_id])
  5. 封装一下DataLoader
    1
    2
    3
    dataset = MyDataset()
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)
  6. 训练时打乱数据。sampler.set_epoch(epoch)
  7. 保存只在单卡上进行。
    1
    2
    3
    if rank == 0:
    torch.save(ddp_model.state_dict(), ckpt_path)
    dist.barrier()
  8. 读取数据时注意map_location,也要注意参数名里的module
    1
    2
    3
    map_location = {'cuda:0': f'cuda:{device_id}'}
    state_dict = torch.load(ckpt_path, map_location=map_location)
    ddp_model.load_state_dict(state_dict)

项目网址:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/PyTorchDistributed

参考资料:

  1. 官方教程:https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
  2. 另一个展示简单Demo的文章:https://zhuanlan.zhihu.com/p/350301395

变分自编码器(VAE)是一类常见的生成模型。纯VAE的生成效果不见得是最好的,但VAE还是经常会被用作大模型的子模块。即使是在VAE发明多年的今天,学习VAE还是很有必要的。相比GAN等更符合直觉的模型,彻底理解VAE对数学的要求较高。在这篇文章中,我会从计算机科学的角度出发,简明地讲清楚VAE的核心原理,并附上代码实现的介绍。同时,我会稍微提及VAE是怎么利用数学知识的,以及该怎么去拓展了解这些数学知识。

用自编码器生成图像

在正式开始学习VAE之前,我们先探讨一下内容生成的几种方式,并引入自编码器(Autoencoder, AE)这个概念。为了方面描述,我们仅讨论图像的生成。

在设计生成图像的程序之前,我们要考虑一个问题——程序的输入是什么?如果程序没有任何输入,那么它就应该有一个确定的输出,也就是只能画出一幅图片。而只能画出一幅图片的程序没有任何意义的。因此,一个图像生成模型一定要有输入,用于区分不同的图片。哪怕这种输入仅仅是0, 1, 2这种序号也可以,只要模型能看懂输入,为每个输入生成不同的图片就行了。

可是,我们不仅希望不同的输入能区分不同的图片,还要让相近的输入生成相近的图片。比如1.5号图片应该长得和1号和2号相似。为了让模型满足这种性质,我们可以干脆把模型的输入建模成有意义的高维实数向量。这个向量,可以是看成对图像的一种压缩编码。比如(170, 1)就表示一幅身高为170cm的男性的照片。

绝大多数生成模型都是用这种方式对生成过程建模。所有的输入向量$z$来自于一个标准正态分布$Z$。图像生成,就是把图像的编码向量$z$解码成一幅图像的过程。不同的生成模型,只是对这个过程有着不同的约束方式。

自编码器的约束方式十分巧妙:既然把$z$翻译回图像是一个解码的过程,为什么不可以把编码的过程也加进来,让整个过程自动学习呢?如下图所示,我们可以让一个模型(编码器)学会怎么把图片压缩成一个编码,再让另一个模型(解码器)学会怎么把编码解压缩成一幅图片,最小化生成图片与原图片之间的误差。

最后,解码器就是我们需要的生成模型。只要在标准多元正态分布里采样出$z$,就可生成图片了。另外,理想情况下,$z$之间的插值向量也能代表在语义上插值的图片。

可是,由于自编码器本身的限制,这种理想不一定能实现。

自编码器的问题——过拟合

自编码器的信息压缩能力十分强大。只要编码器和解码器的神经网络足够复杂,所有训练集里的图像都可以被压缩成非常短的编码。这种编码短到什么程度了呢?——只要一个一维向量(实数)就可以描述所有训练集里的图像了。

想做到这一点并不难。还记得我们开头对生成模型的输入的讨论吗?只要让模型把所有图片以数组的形式存到编码器和解码器里,以0, 1, 2这样的序号来表示每一幅训练集里的图片,就能完成最极致的信息压缩。当然,使用这种方式的话,编码$z$就失去了所有的语义信息,编码之间的插值也不能表示图像语义上的插值了。

这是由模型过拟合导致的。如果仅使用自编码器本身的约束方式,而不加入其他正则化方法的话,一定会出现过拟合。

VAE——一种正则化的自编码器

VAE就是一种使用了某种正则化方法的自编码器,它解决了上述的过拟合问题。VAE使用的这种方法来自于概率论的变分推理,不过,我们可以在完全不了解变分推理的前提下看懂VAE。

VAE的想法是这样的:我们最终希望得到一个分布$Z$,或者说一条连续的直线。可是,编码器每次只能把图片编码成一个向量,也就是一个点。很多点是很难重建出一条连续的直线的。既然如此,我们可以把每张图片也编码成一个分布。多条直线,就可以比较容易地拼成我们想要的直线了。

当然,只让模型去拟合分布是不够的。如果各个分布都乱七八糟,相距甚远,那么它们怎么都拼不成一个标准正态分布。因此,我们还需要加上一个约束,让各个分布和标准正态分布尽可能相似。

这样,我们可以总结一下VAE的训练框架。VAE依然使用了编码器-解码器的架构。只不过,编码器的输出是一个可学习的正态分布。对分布是不可能做求导和梯度下降的,但我们可以去分布里采样,对采样出来的编码$z$解码并求导。

另外,VAE的损失函数除了要最小化重建图像与原图像之间的均方误差外,还要最大化每个分布和标准正态分布之间的相似度。

常见的描述分布之间相似度的指标叫做KL散度。只要把KL散度的公式套进损失函数里,整个训练框架就算搭好了。

如果你对KL散度的原理感兴趣,欢迎阅读我的上一篇文章:从零理解熵、交叉熵、KL散度

VAE的原理其实就是这么简单。总结一下,VAE本身是一个编码器-解码器结构的自编码器,只不过编码器的输出是一个分布,而解码器的输入是该分布的一个样本。另外,在损失函数中,除了要让重建图像和原图像更接近以外,还要让输出的分布和标准正态分布更加接近。

VAE 与变分推理

前几段其实只对VAE做了一个直觉上的描述,VAE的损失函数实际上是经严谨的数学推导得到的。如果你对数学知识不感兴趣,完全可以跳过这一节的讲解。当然,这一节也只会简单地描述VAE和变分推理的关系,更详细的数学推导可以去参考网上的其他文章。

让我们从概率论的角度看待生成模型。生成模型中的$z$可以看成是隐变量,它决定了能观测到的变量$x$。比如说,袋子里有黑球和白球,你不断地从袋子里取球出来再放回去,就能够统计出抽到黑球和白球的频率。然而,真正决定这个频率的,是袋子里黑球和白球的数量,这些数量就是观测不到的隐变量。简单来说,隐变量$z$是因,变量$x$是果。

生成模型,其实就是假设$z$来自标准正态分布,想要拟合分布$P(x|z)$(解码器),以得到$x$的分布(图像分布)。为了训练解码器,自编码器架构使用了一个编码器以描述$P(z|x)$。这样,从训练集里采样,等于是采样出了一个$x$。根据$P(z|x)$求出一个$z$,再根据$P(x|z)$试图重建$x$。优化这个过程,就是在优化编码器和解码器,也就是优化$P(z|x)$和$P(x|z)$。

然而,$P(z|x)$和$P(x|z)$之间有一个约束,它们必须满足贝叶斯公式:

假如我们要用一个和$x$有关的关于$z$的分布$Q_x(z)$去拟合$P(z|x)$,就要让$Q_x(z)$和$\frac{P(x|z)P(z)}{P(x)}$这两个分布尽可能相似。如果这个相似度是KL散度,经过一系列的推导,就可以推导出我们在VAE里使用的那个损失函数。

简单来说,拟合一个未知分布的技术就叫做变分推理。VAE利用变分推理,对模型的编码器和解码器加了一个约束,这个约束在化简后就是VAE的损失函数。

VAE和变分推理的关系就是这样。如果还想细究,可以去先学习KL散度相关的知识,再去看一下VAE中KL散度的公式推导。当然,不懂这些概念并不影响VAE的学习。

总结

VAE其实就是一个编码器-解码器架构,和U-Net以及部分NLP模型类似。然而,为了抑制自编码过程中的过拟合,VAE编码器的输出是一个正态分布,而不是一个具体的编码。同时,VAE的损失函数除了约束重建图像外,还约束了生成的分布。在这些改进下,VAE能够顺利地训练出一个解码器,以把来自正态分布的随机变量$z$画成一幅图像。

如果你想通过代码实践进一步加深对VAE的理解,可以阅读附录。

参考资料

  1. 一篇不错的VAE讲解。我是跟着这篇文章学习的。https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73
  2. 我的VAE PyTorch实现参考了这个仓库:https://github.com/AntixK/PyTorch-VAE 。开头的人脸生成效果图是从这个项目里摘抄过来的。

VAE PyTorch 实现

项目网址:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/VAE

数据集

在这个项目中,我使用了CelebA数据集。这个数据集有200k张人脸,裁剪和对齐后的图片只有1个多G,对实验非常友好。

CelebA的下载链接可以在官方网站上找到:https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html。

下载好了图片后,可以用下面的代码创建Dataloader。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class CelebADataset(Dataset):
def __init__(self, root, img_shape=(64, 64)) -> None:
super().__init__()
self.root = root
self.img_shape = img_shape
self.filenames = sorted(os.listdir(root))

def __len__(self) -> int:
return len(self.filenames)

def __getitem__(self, index: int):
path = os.path.join(self.root, self.filenames[index])
img = Image.open(path).convert('RGB')
pipeline = transforms.Compose([
transforms.CenterCrop(168),
transforms.Resize(self.img_shape),
transforms.ToTensor()
])
return pipeline(img)


def get_dataloader(root='data/celebA/img_align_celeba', **kwargs):
dataset = CelebADataset(root, **kwargs)
return DataLoader(dataset, 16, shuffle=True)

这段代码是一段非常常规的根据图片路径读取图片的代码。只有少数地方需要说明:

  • 为了尽快完成demo,所有人脸图片的分辨率都是$64 \times 64$。
  • CelebA里裁剪后的人脸图片是长方形的。要先调用CenterCrop裁剪出正方形人脸,再做Resize。

为了验证Dataloader的正确性,我们可以写一些脚本来查看Dataloader里的一个batch的图片。

1
2
3
4
5
6
7
8
9
10
11
12
13
if __name__ == '__main__':
dataloader = get_dataloader()
img = next(iter(dataloader))
print(img.shape)
# Concat 4x4 images
N, C, H, W = img.shape
assert N == 16
img = torch.permute(img, (1, 0, 2, 3))
img = torch.reshape(img, (C, 4, 4 * H, W))
img = torch.permute(img, (0, 2, 1, 3))
img = torch.reshape(img, (C, 4 * H, 4 * W))
img = transforms.ToPILImage()(img)
img.save('work_dirs/tmp.jpg')

这段代码使用了一些小技巧。首先,next(iter(dataloader))可以访问Dataloader的第一个数据。其次,在把一个batch的图片转换成图片方格的过程中,我使用了比较骚的换维度、换形状操作,看起来很帅。

模型

我的VAE模型使用了类似U-Net的操作:编码器用卷积把图像的边长减半,通道翻倍,解码器用反卷积把图像的边长翻倍,通道减半。

模型结构的定义函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn


class VAE(nn.Module):
'''
VAE for 64x64 face generation. The hidden dimensions can be tuned.
'''
def __init__(self, hiddens=[16, 32, 64, 128, 256], latent_dim=128) -> None:
super().__init__()

# encoder
prev_channels = 3
modules = []
img_length = 64
for cur_channels in hiddens:
modules.append(
nn.Sequential(
nn.Conv2d(prev_channels,
cur_channels,
kernel_size=3,
stride=2,
padding=1), nn.BatchNorm2d(cur_channels),
nn.ReLU()))
prev_channels = cur_channels
img_length //= 2
self.encoder = nn.Sequential(*modules)
self.mean_linear = nn.Linear(prev_channels * img_length * img_length,
latent_dim)
self.var_linear = nn.Linear(prev_channels * img_length * img_length,
latent_dim)
self.latent_dim = latent_dim
# decoder
modules = []
self.decoder_projection = nn.Linear(
latent_dim, prev_channels * img_length * img_length)
self.decoder_input_chw = (prev_channels, img_length, img_length)
for i in range(len(hiddens) - 1, 0, -1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hiddens[i],
hiddens[i - 1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hiddens[i - 1]), nn.ReLU()))
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hiddens[0],
hiddens[0],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hiddens[0]), nn.ReLU(),
nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
nn.ReLU()))
self.decoder = nn.Sequential(*modules)

首先来看编码器的部分。每个卷积模块由卷积、BN、ReLU构成。卷完了再用两个全连接层分别生成正态分布的均值和方差。注意,卷积完成后,图像的形状是[prev_channels, img_length, img_length],为了把它输入到全连接层,我们到时候会做一个flatten操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# encoder
prev_channels = 3
modules = []
img_length = 64
for cur_channels in hiddens:
modules.append(
nn.Sequential(
nn.Conv2d(prev_channels,
cur_channels,
kernel_size=3,
stride=2,
padding=1), nn.BatchNorm2d(cur_channels),
nn.ReLU()))
prev_channels = cur_channels
img_length //= 2
self.encoder = nn.Sequential(*modules)
self.mean_linear = nn.Linear(prev_channels * img_length * img_length,
latent_dim)
self.var_linear = nn.Linear(prev_channels * img_length * img_length,
latent_dim)
self.latent_dim = latent_dim

解码器和编码器的操作基本完全相反。由于隐变量的维度是latent_dim,需要再用一个全连接层把图像的维度投影回[prev_channels, img_length, img_length]。之后就是反卷积放大图像的过程。写这些代码时一定要算好图像的边长,定好反卷积的次数,并且不要忘记最后把图像的通道数转换回3。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# decoder
modules = []
self.decoder_projection = nn.Linear(
latent_dim, prev_channels * img_length * img_length)
self.decoder_input_chw = (prev_channels, img_length, img_length)
for i in range(len(hiddens) - 1, 0, -1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hiddens[i],
hiddens[i - 1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hiddens[i - 1]), nn.ReLU()))
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hiddens[0],
hiddens[0],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hiddens[0]), nn.ReLU(),
nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
nn.ReLU()))
self.decoder = nn.Sequential(*modules)

网络前向传播的过程如正文所述,先是用编码器编码,把图像压平送进全连接层得到均值和方差,再用randn_like随机采样,把采样的z投影、变换成正确的维度,送入解码器,最后输出重建图像以及正态分布的均值和方差。

1
2
3
4
5
6
7
8
9
10
11
12
13
def forward(self, x):
encoded = self.encoder(x)
encoded = torch.flatten(encoded, 1)
mean = self.mean_linear(encoded)
logvar = self.var_linear(encoded)
eps = torch.randn_like(logvar)
std = torch.exp(logvar / 2)
z = eps * std + mean
x = self.decoder_projection(z)
x = torch.reshape(x, (-1, *self.decoder_input_chw))
decoded = self.decoder(x)

return decoded, mean, logvar

用该模型随机生成图像的过程和前向传播的过程十分类似,只不过$z$来自于标准正态分布而已,解码过程是一模一样的。

1
2
3
4
5
6
def sample(self, device='cuda'):
z = torch.randn(1, self.latent_dim).to(device)
x = self.decoder_projection(z)
x = torch.reshape(x, (-1, *self.decoder_input_chw))
decoded = self.decoder(x)
return decoded

主函数

在主函数中,我们要先完成模型训练。在训练前,还有一件重要的事情要做:定义损失函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from time import time

import torch
import torch.nn.functional as F
from torchvision.transforms import ToPILImage

from dldemos.VAE.load_celebA import get_dataloader
from dldemos.VAE.model import VAE

# Hyperparameters
n_epochs = 10
kl_weight = 0.00025
lr = 0.005


def loss_fn(y, y_hat, mean, logvar):
recons_loss = F.mse_loss(y_hat, y)
kl_loss = torch.mean(
-0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1), 0)
loss = recons_loss + kl_loss * kl_weight
return loss

如正文所述,VAE的loss包括两部分:图像的重建误差和分布之间的KL散度。二者的比例可以通过kl_weight来控制。

KL散度的公式直接去网上照抄即可。

这里要解释一下,我们的方差为什么使用其自然对数logvar。经过我的实验,如果让模型输出方差本身的话,就要在损失函数里对齐取一次自然对数。如果方差很小,趋于0的话,方差的对数就趋于无穷。这表现在loss里会出现nan。因此,在神经网络中我们应该避免拟合要取对数的数,而是直接去拟合其对数运算结果。

准备好了损失函数,剩下就是常规的训练操作了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def train(device, dataloader, model):
optimizer = torch.optim.Adam(model.parameters(), lr)
dataset_len = len(dataloader.dataset)

begin_time = time()
# train
for i in range(n_epochs):
loss_sum = 0
for x in dataloader:
x = x.to(device)
y_hat, mean, logvar = model(x)
loss = loss_fn(x, y_hat, mean, logvar)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_sum += loss
loss_sum /= dataset_len
training_time = time() - begin_time
minute = int(training_time // 60)
second = int(training_time % 60)
print(f'epoch {i}: loss {loss_sum} {minute}:{second}')
torch.save(model.state_dict(), 'dldemos/VAE/model.pth')

训练好模型后,想要查看模型重建数据集图片的效果也很简单,去dataloader里采样、跑模型、后处理结果即可。

1
2
3
4
5
6
7
8
9
10
def reconstruct(device, dataloader, model):
model.eval()
batch = next(iter(dataloader))
x = batch[0:1, ...].to(device)
output = model(x)[0]
output = output[0].detach().cpu()
input = batch[0].detach().cpu()
combined = torch.cat((output, input), 1)
img = ToPILImage()(combined)
img.save('work_dirs/tmp.jpg')

想用模型随机生成图片的话,可以利用之前写好的模型采样函数。

1
2
3
4
5
6
def generate(device, model):
model.eval()
output = model.sample(device)
output = output[0].detach().cpu()
img = ToPILImage()(output)
img.save('work_dirs/tmp.jpg')

在3090上跑这个实验,100个epoch需要5个多小时。但是,模型差不多在10多个epoch的时候就收敛了。

最朴素的VAE的重建效果并不是很好,只能大概看出个脸型。这可能也和我的模型参数较少有关。

随机生成的图片也是形状还可以,但非常模糊。