0%

相比于多数图像生成模型,去噪扩散概率模型(Denoising Diffusion Probabilistic Model, DDPM)的采样速度非常慢。这是因为DDPM在采样时通常要做1000次去噪操作。但如果你玩过基于扩散模型的图像生成应用的话,你会发现,大多数应用只需要20次去噪即可生成图像。这是为什么呢?原来,这些应用都使用了一种更快速的采样方法——去噪扩散隐式模型(Denoising Diffusion Implicit Model, DDIM)。

基于DDPM,DDIM论文主要提出了两项改进。第一,对于一个已经训练好的DDPM,只需要对采样公式做简单的修改,模型就能在去噪时「跳步骤」,在一步去噪迭代中直接预测若干次去噪后的结果。比如说,假设模型从时刻$T=100$开始去噪,新的模型可以在每步去噪迭代中预测10次去噪操作后的结果,也就是逐步预测时刻$t=90, 80, …, 0$的结果。这样,DDPM的采样速度就被加速了10倍。第二,DDIM论文推广了DDPM的数学模型,从更高的视角定义了DDPM的前向过程(加噪过程)和反向过程(去噪过程)。在这个新数学模型下,我们可以自定义模型的噪声强度,让同一个训练好的DDPM有不同的采样效果。

在这篇文章中,我将言简意赅地介绍DDIM的建模方法,并给出我的DDIM PyTorch实现与实验结果。本文不会深究DDIM的数学推导,对这部分感兴趣的读者可以去阅读我在文末给出的参考资料。

回顾 DDPM

DDIM是建立在DDPM之上的一篇工作。在正式认识DDIM之前,我们先回顾一下DDPM中的一些关键内容,再从中引出DDIM的改进思想。

DDPM是一个特殊的VAE。它的编码器是$T$步固定的加噪操作,解码器是$T$步可学习的去噪操作。模型的学习目标是让每一步去噪操作尽可能抵消掉对应的加噪操作。

DDPM的加噪和去噪操作其实都是在某个正态分布中采样。因此,我们可以用概率$q, p$分别表示加噪和去噪的分布。比如 $q(\mathbf{x}_t|\mathbf{x}_{t-1})$ 就是由第 $t-1$ 时刻的图像到第 $t$ 时刻的图像的加噪声分布, $p(\mathbf{x}_{t-1}|\mathbf{x}_{t})$ 就是由第 $t$ 时刻的图像到第 $t-1$ 时刻的图像的去噪声分布。这样,我们可以说网络的学习目标是让 $p(\mathbf{x}_{t-1} | \mathbf{x}_{t})$ 尽可能与 $q(\mathbf{x}_t | \mathbf{x}_{t-1})$ 和互逆。

但是,「互逆」并不是一个严格的数学表述。更具体地,我们应该让分布$p(\mathbf{x}_{t-1} | \mathbf{x}_{t})$和分布$q(\mathbf{x}_{t-1} | \mathbf{x}_{t})$尽可能相似。$q(\mathbf{x}_{t-1} | \mathbf{x}_{t})$和$p(\mathbf{x}_{t-1} | \mathbf{x}_{t})$的关系就和VAE中原图像与重建图像的关系一样。

$q(\mathbf{x}_{t-1} | \mathbf{x}_{t})$是不好求得的,但在给定了输入数据$\mathbf{x}_{0}$时,$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0})$是可以用贝叶斯公式求出来的:

我们不必关心具体的求解方法,只需要知道从等式右边的三项$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)、q(\mathbf{x}_{t-1} | \mathbf{x}_0)、q(\mathbf{x}_{t} | \mathbf{x}_0)$可以推导出等式左边的那一项。在DDPM中,$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1})$是一个定义好的式子,且$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}) = q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$。根据$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1})$,可以推出$q(\mathbf{x}_{t} | \mathbf{x}_0)$。知道了$q(\mathbf{x}_{t} | \mathbf{x}_0)$,$q(\mathbf{x}_{t-1} | \mathbf{x}_0)$也就知道了(把公式里的$t$换成$t-1$就行了)。这样,在DDPM中,等式右边的式子全部已知,等式左边的$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_{0})$可以直接求出来。

上述推理过程可以简单地表示为:知道$q(\mathbf{x}_{t} | \mathbf{x}_0)$和$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$,就知道了神经网络的学习目标$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$。这几个公式在DDPM中的具体形式如下:

其中,只有参数$\beta_t$是可调的。$\bar{\alpha}_t$是根据$\beta_t$算出的变量,其计算方法为:$\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i$。

由于学习目标$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$里只有一个未知变量$\mathbf{x}_0$,DDPM把学习目标简化成了只让神经网络根据$\mathbf{x}_{t}$拟合公式里的$\mathbf{x}_{0}$(更具体一点,是拟合从$\mathbf{x}_{0}$到$\mathbf{x}_{t}$的噪声)。也就是说,在训练时,$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$的公式不会被用到,只有$\mathbf{x}_{t}$和$\mathbf{x}_{0}$两个量之间的公式$q(\mathbf{x}_{t} | \mathbf{x}_0)$会被用到。只有在采样时,$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$的公式才会被用到。训练目标的推理过程可以总结为:

理解「DDPM的训练目标只有$\mathbf{x}_{0}$」对于理解DDIM非常关键。如果你在回顾DDPM时出现了问题,请再次阅读DDPM的相关介绍文章。

加速 DDPM

我们再次审视一下DDPM的推理过程:首先有$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}) = q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$。根据$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1})$,可以推出$q(\mathbf{x}_{t} | \mathbf{x}_0)$。知道$q(\mathbf{x}_{t} | \mathbf{x}_0)$和$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$,由贝叶斯公式,就知道了学习目标$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$。

根据这一推理过程,DDIM论文的作者想到,假如我们把贝叶斯公式中的$t$替换成$t_2$, $t-1$替换成$t_1$,其中$t_2$是比$t_1$大的任意某一时刻,那么我们不就可以从$t_2$到$t_1$跳步骤去噪了吗?比如令$t_2 = t_1 + 10$,我们就可以求出去除10次噪声的公式,去噪的过程就快了10倍。

修改之后,$q(\mathbf{x}_{t_1} | \mathbf{x}_0)$和$q(\mathbf{x}_{t_2} | \mathbf{x}_0)$依然很好求,只要把$t_1$, $t_2$代入普通的$q(\mathbf{x}_{t} | \mathbf{x}_0)$公式里就行。

但是,$q(\mathbf{x}_{t_2} | \mathbf{x}_{t_1}, \mathbf{x}_0)$怎么求呢?原来的$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t};\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})$来自于DDPM的定义,我们能直接把公式拿来用。能不能把$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的公式稍微修改一下,让它兼容$q(\mathbf{x}_{t_2} | \mathbf{x}_{t_1}, \mathbf{x}_0)$呢?

修改$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的思路如下:假如我们能把公式中的$\beta_t$换成一个由$t$和$t-1$决定的变量,我们就能把$t$换成$t_2$,$t-1$换成$t_1$,也就得到了$q(\mathbf{x}_{t_2} | \mathbf{x}_{t_1}, \mathbf{x}_0)$。

那怎么修改$\beta_t$的形式呢?很简单。我们知道$\beta_t$决定了$\bar{\alpha}_t$:$\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i$。那么我们用$\bar{\alpha}_t$除以$\bar{\alpha}_{t-1}$,不就得到了$1-\beta_t$了吗?也就是说:

我们把这个用$\bar{\alpha}_t$和$\bar{\alpha}_{t-1}$表示的$\beta_t$套入$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的公式里,再把$t$换成$t_2$,$t-1$换成$t_1$,就得到了$q(\mathbf{x}_{t_2} | \mathbf{x}_{t_1}, \mathbf{x}_0)$。有了这一项,贝叶斯公式等式右边那三项我们就全部已知,可以求出$q(\mathbf{x}_{t_1} | \mathbf{x}_{t_2}, \mathbf{x}_0)$,也就是可以一次性得到多个时刻后的去噪结果。

在这个过程中,我们只是把DDPM公式里的$\bar{\alpha}_t$换成$\bar{\alpha}_{t2}$,$\bar{\alpha}_{t-1}$换成$\bar{\alpha}_{t1}$,公式推导过程完全不变。网络的训练目标$\mathbf{x}_{0}$也没有发生改变,只是采样时的公式需要修改。这意味着我们可以先照着原DDPM的方法训练,再用这种更快速的方式采样。

我们之前只讨论了$t_1$到$t_2$为固定值的情况。实际上,我们不一定要间隔固定的时刻去噪一次,完全可以用原时刻序列的任意一个子序列来去噪。比如去噪100次的DDPM的去噪时刻序列为[99, 98, ..., 0],我们可以随便取一个长度为10的子序列:[99, 98, 77, 66, 55, 44, 33, 22, 1, 0],按这些时刻来去噪也能让采样速度加速10倍。但实践中没人会这样做,一般都是等间距地取时刻。

这样看来,在采样时,只有部分时刻才会被用到。那我们能不能顺着这个思路,干脆训练一个有效时刻更短(总时刻$T$不变)的DDPM,以加速训练呢?又或者保持有效训练时刻数不变,增大总时刻$T$呢?DDIM论文的作者提出了这些想法,认为这可以作为后续工作的研究方向。

从 DDPM 到 DDIM

除了加速DDPM外,DDIM论文还提出了一种更普遍的DDPM。在这种新的数学模型下,我们可以任意调节采样时的方差大小。让我们来看一下这个数学模型的推导过程。

DDPM的学习目标$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$由$q(\mathbf{x}_{t} | \mathbf{x}_0)$和$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$决定。具体来说,在求解正态分布$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$时,我们会将它的均值$\tilde{\mu}_t$和方差$\tilde{\beta}_t$设为未知量,并将条件$q(\mathbf{x}_{t} | \mathbf{x}_0)$、$q(\mathbf{x}_{t-1} | \mathbf{x}_0)$、$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$代入,求解出确定的$\tilde{\mu}_t$和$\tilde{\beta}_t$。

在上文我们分析过,DDPM训练时只需要拟合$\mathbf{x}_0$,只需要用到$\mathbf{x}_0$和$\mathbf{x}_t$的关系$q(\mathbf{x}_{t} | \mathbf{x}_0)$。在不修改训练过程的前提下,我们能不能把限制$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$去掉(即$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$可以是任意一个正态分布,而不是我们提前定义好的一个正态分布),得到一个更普遍的DDPM呢?

这当然是可以的。根据基础的解方程知识,我们知道,去掉一个方程后,会多出一个自由变量。取消了$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的限制后,均值$\tilde{\mu}_t$和方差$\tilde{\beta}_t$就不能同时确定下来了。我们可以令方差$\tilde{\beta}_t$为自由变量,并让$\tilde{\mu}_t$用含$\tilde{\beta}_t$的式子表示出来。这样,我们就得到了一个方差可变的更一般的DDPM。

让我们来看一下这个新模型的具体公式。原来的DDPM的加噪声逆操作的分布为:

新的分布公式为:

新公式是旧公式的一个推广版本。如果我们把DDPM的方差$(1-\bar{\alpha}_{t-1})/(1 - \bar{\alpha}_{t}) \cdot \beta_t$代入新公式里的$\tilde{\beta}_t$,就能把新公式还原成DDPM的公式。和DDPM的公式一样,我们也可以把$\mathbf{x}_{0}$拆成$\mathbf{x}_{t}$和噪声$\epsilon$表示的式子。

现在采样时方差可以随意取了,我们来讨论一种特殊的方差取值——$\tilde{\beta}_t=0$。也就是说,扩散模型的反向过程变成了一个没有噪声的确定性过程。给定随机噪声$\mathbf{x}_{T}$,我们只能得到唯一的采样结果$\mathbf{x}_{0}$。这种结果确定的概率模型被称为隐式概率模型(implicit probabilistic model)。所以,论文作者把方差为0的这种扩散模型称为DDIM(Denoising Diffusion Implicit Model)。

为了方便地选取方差值,作者将方差改写为

其中,$\eta\in[0, 1]$。通过选择不同的$\eta$,我们实际上是在DDPM和DDIM之间插值。$\eta$控制了插值的比例。$\eta=0$,模型是DDIM;$\eta=1$,模型是DDPM。

除此之外,DDPM论文曾在采样时使用了另一种方差取值:$\tilde{\beta}_t=\beta_t$,即去噪方差等于加噪方差。实验显示这个方差的采样结果还不错。我们可以把这个取值也用到DDIM论文提出的方法里,只不过这个方差值不能直接套进上面的公式。在代码实现部分我会介绍该怎么在DDIM方法中使用这个方差。

注意,在这一节的推导过程中,我们依然没有修改DDPM的训练目标。我们可以把这种的新的采样方法用在预训练的DDPM上。当然,我们可以在使用新的采样方法的同时也使用上一节的加速采样方法。

实验

到这里为止,我们已经学完了DDIM论文的两大内容:加速采样、更换采样方差。加速采样的意义很好理解,它能大幅减少采样时间。可更换采样方差有什么意义呢?我们看完论文中的实验结果就知道了。

论文展示了新采样方法在不同方差、不同采样步数下的FID指标(越小越好)。其中,$\hat{\sigma}$表示使用DDPM中的$\tilde{\beta}_t=\beta_t$方差取值。实验结果非常有趣。在使用采样加速(步数比总时刻1000要小)时,$\eta=0$的DDIM的表现最好,而$\hat{\sigma}$的情况则非常差。而当$\eta$增大,模型越来越靠近DDPM时,用$\hat{\sigma}$的结果会越来越好。而在DDPM中,用$\hat{\sigma}$的结果是最好的。

从这个实验结果中,我们可以得到一条很简单的实践指南:如果使用了采样加速,一定要用效果最好的DDIM;而使用原DDPM的话,可以维持原论文提出的$\tilde{\beta}_t=\beta_t$方差取值。

总结

DDIM论文提出了DDPM的两个拓展方向:加速采样、变更采样方差。通过同时使用这两个方法,我们能够在不重新训练DDPM、尽可能不降低生成质量的前提下,让扩散模型的采样速度大幅提升(一般可以快50倍)。让我们再从头理一理提出DDIM方法的思考过程。

为了能直接使用预训练的DDPM,我们希望在改进DDPM时不更改DDPM的训练过程。而经过简化后,DDPM的训练目标只有拟合$\mathbf{x}_{0}$,训练时只会用到前向过程公式$q(\mathbf{x}_{t} | \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t}; \sqrt{\bar{\alpha}_t}\mathbf{x}_{0}, (1-\bar{\alpha}_t)\mathbf{I})$。所以,我们的改进应该建立在公式$q(\mathbf{x}_{t} | \mathbf{x}_0)$完全不变的前提下。

通过对DDPM反向过程公式的简单修改,也就是把$t$改成$t_2$,$t-1$改成$t_1$,我们可以把去噪一步的公式改成去噪多步的公式,以大幅加速DDPM。可是,这样改完之后,采样的质量会有明显的下降。

我们可以猜测,减少了采样迭代次数后,采样质量之所以下降,是因为每次估计的去噪均值更加不准确。而每次去噪迭代中的噪声(由方差项决定的那一项)放大了均值的不准确性。我们能不能干脆让去噪时的方差为0呢?为了让去噪时的方差可以自由变动,我们可以去掉DDPM的约束条件。由于贝叶斯公式里的$q(\mathbf{x}_{t} | \mathbf{x}_0)$不能修改,我们只能去掉$q(\mathbf{x}_{t} | \mathbf{x}_{t - 1}, \mathbf{x}_0)$的限制。去掉限制后,方差就成了自由变量。我们让去噪方差为0,让采样过程没有噪声。这样,就得到了本文提出的DDIM模型。实验证明,在采样迭代次数减少后,使用DDIM的生成结果是最优的。

在本文中,我较为严格地区分了DDPM和DDIM的叫法:DDPM指DDPM论文中提出的有1000个扩散时刻的模型,它的采样方差只有两种取值($\tilde{\beta}_t=(1-\bar{\alpha}_{t-1})/(1 - \bar{\alpha}_{t}) \cdot \beta_t$, $\tilde{\beta}_t=\beta_t$)。DDIM指DDIM论文中提出的$\eta=0$的推广版DDPM模型。DDPM和DDIM都可以使用采样加速。但是,从习惯上我们会把没有优化加速的DDPM称为”DDPM”,把$\eta$可以任取,采样迭代次数可以任取的采样方法统称为”DDIM”。一些开源库中会有叫DDIMSampler的类,调节$\eta$的参数大概会命名为eta,调节迭代次数的参数大概会命名为ddim_num_steps。一般我们令eta=0ddim_num_steps=20即可。

DDIM的代码实现没有太多的学习价值,只要在DDPM代码的基础上把新数学公式翻译成代码即可。其中唯一值得注意的就是如何在DDIM中使用DDPM的方差$\tilde{\beta}_t=\beta_t$。对此感兴趣的话可以阅读我接下来的代码实现介绍。

在这篇解读中,我略过了DDIM论文中的大部分数学推导细节。对DDIM数学模型的推导过程感兴趣的话,可以阅读我在参考文献中推荐的文章,或者看一看原论文。

DDIM PyTorch 实现

在这个项目中,我们将对一个在CelebAHQ上预训练的DDPM执行DDIM采样,尝试复现论文中的那个FID表格,以观察不同etaddim_steps对于采样结果的影响。

代码仓库:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/ddim

DDPM 基础项目

DDIM只是DDPM的一种采样改进策略。为了复现DDIM的结果,我们需要一个DDPM基础项目。由于DDPM并不是本文的重点,在这一小节里我将简要介绍我的DDPM实现代码的框架。

我们的实验需要使用CelebAHQ数据集,请在 https://www.kaggle.com/datasets/badasstechie/celebahq-resized-256x256 下载该数据集并解压到项目的data/celebA/celeba_hq_256目录下。另外,我在Hugging Face上分享了一个在64x64 CelebAHQ上训练的DDPM模型:https://huggingface.co/SingleZombie/dldemos/tree/main/ckpt/ddim ,请把它放到项目的dldemos/ddim目录下。

先运行dldemos/ddim/dataset.py下载MNIST,再直接运行dldemos/ddim/main.py,代码会自动完成MNIST上的训练,并执行步数1000的两种采样和步数20的三种采样,同时将结果保存在目录work_dirs中。以下是我得到的MNIST DDPM采样结果(存储在work_dirs/diffusion_ddpm_sigma_hat.jpg中)。

为了查看64x64 CelebAHQ上的采样结果,可以在dldemos/ddim/main.py的main函数里把config_id改成2,再注释掉训练函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 0 for MNIST. See configs.py
config_id = 2
cfg = configs[config_id]
n_steps = 1000
device = 'cuda'
model_path = cfg['model_path']
img_shape = cfg['img_shape']
to_bgr = False if cfg['dataset_type'] == 'MNIST' else True

net = UNet(n_steps, img_shape, cfg['channels'], cfg['pe_dim'],
cfg.get('with_attn', False), cfg.get('norm_type', 'ln'))
ddpm = DDPM(device, n_steps)

# train(ddpm,
# net,
# cfg['dataset_type'],
# resolution=(img_shape[1], img_shape[2]),
# batch_size=cfg['batch_size'],
# n_epochs=cfg['n_epochs'],
# device=device,
# ckpt_path=model_path)

以下是我得到的CelebAHQ DDPM采样结果(存储在work_dirs/diffusion_ddpm_sigma_hat.jpg中)。

项目目录下的configs.py存储了训练配置,dataset.py定义了DataLoadernetwork.py定义了U-Net的结构,ddpm.pyddim.py分别定义了普通的DDPM前向过程和采样以及DDIM采样,dist_train.py提供了并行训练脚本,dist_sample.py提供了并行采样脚本,main.py提供了单卡运行的所有任务脚本。

在这个项目中,我们的主要的目标是基于其他文件,编写ddim.py。我们先来看一下原来的DDPM类是怎么实现的,再仿照它的接口写一个DDIM类。

实现 DDIM 采样

在我的设计中,DDPM类不是一个神经网络(torch.nn.Module),它仅仅维护了扩散模型的alpha等变量,并描述了前向过程和反向过程。

DDPM类中,我们可以在初始化函数里定义好要用到的self.betas, self.alphas, self.alpha_bars变量。如果在工程项目中,我们可以预定义好更多的常量以节约采样时间。但在学习时,我们可以少写一点代码,让项目更清晰一点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class DDPM():

def __init__(self,
device,
n_steps: int,
min_beta: float = 0.0001,
max_beta: float = 0.02):
betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
alphas = 1 - betas
alpha_bars = torch.empty_like(alphas)
product = 1
for i, alpha in enumerate(alphas):
product *= alpha
alpha_bars[i] = product
self.betas = betas
self.n_steps = n_steps
self.alphas = alphas
self.alpha_bars = alpha_bars

前向过程就是把正态分布的公式$q(\mathbf{x}_{t} | \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t}; \sqrt{\bar{\alpha}_t}\mathbf{x}_{0}, (1-\bar{\alpha}_t)\mathbf{I})$翻译一下。

1
2
3
4
5
6
def sample_forward(self, x, t, eps=None):
alpha_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)
if eps is None:
eps = torch.randn_like(x)
res = eps * torch.sqrt(1 - alpha_bar) + torch.sqrt(alpha_bar) * x
return res

在反向过程中,我们从self.n_steps1枚举时刻t(代码中时刻和数组下标有1的偏差),按照公式算出每一步的去噪均值和方差,执行去噪。算法流程如下:

参数simple_var=True表示令方差$\sigma_t^2=\beta_t$,而不是$(1-\bar{\alpha}_{t-1})/(1 - \bar{\alpha}_{t}) \cdot \beta_t$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def sample_backward(self, img_or_shape, net, device, simple_var=True):
if isinstance(img_or_shape, torch.Tensor):
x = img_or_shape
else:
x = torch.randn(img_or_shape).to(device)
net = net.to(device)
for t in tqdm(range(self.n_steps - 1, -1, -1), "DDPM sampling"):
x = self.sample_backward_step(x, t, net, simple_var)

return x

def sample_backward_step(self, x_t, t, net, simple_var=True):

n = x_t.shape[0]
t_tensor = torch.tensor([t] * n,
dtype=torch.long).to(x_t.device).unsqueeze(1)
eps = net(x_t, t_tensor)

if t == 0:
noise = 0
else:
if simple_var:
var = self.betas[t]
else:
var = (1 - self.alpha_bars[t - 1]) / (
1 - self.alpha_bars[t]) * self.betas[t]
noise = torch.randn_like(x_t)
noise *= torch.sqrt(var)

mean = (x_t -
(1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *
eps) / torch.sqrt(self.alphas[t])
x_t = mean + noise

return x_t

接下来,我们来实现DDIM类。DDIMDDPM的推广,我们可以直接用DDIM类继承DDPM类。它们共享初始化函数与前向过程函数。

1
2
3
4
5
6
7
8
class DDIM(DDPM):

def __init__(self,
device,
n_steps: int,
min_beta: float = 0.0001,
max_beta: float = 0.02):
super().__init__(device, n_steps, min_beta, max_beta)

我们要修改的只有反向过程的实现函数。整个函数的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def sample_backward(self,
img_or_shape,
net,
device,
simple_var=True,
ddim_step=20,
eta=1):
if simple_var:
eta = 1
ts = torch.linspace(self.n_steps, 0,
(ddim_step + 1)).to(device).to(torch.long)
if isinstance(img_or_shape, torch.Tensor):
x = img_or_shape
else:
x = torch.randn(img_or_shape).to(device)
batch_size = x.shape[0]
net = net.to(device)
for i in tqdm(range(1, ddim_step + 1),
f'DDIM sampling with eta {eta} simple_var {simple_var}'):
cur_t = ts[i - 1] - 1
prev_t = ts[i] - 1

ab_cur = self.alpha_bars[cur_t]
ab_prev = self.alpha_bars[prev_t] if prev_t >= 0 else 1

t_tensor = torch.tensor([cur_t] * batch_size,
dtype=torch.long).to(device).unsqueeze(1)
eps = net(x, t_tensor)
var = eta * (1 - ab_prev) / (1 - ab_cur) * (1 - ab_cur / ab_prev)
noise = torch.randn_like(x)

first_term = (ab_prev / ab_cur)**0.5 * x
second_term = ((1 - ab_prev - var)**0.5 -
(ab_prev * (1 - ab_cur) / ab_cur)**0.5) * eps
if simple_var:
third_term = (1 - ab_cur / ab_prev)**0.5 * noise
else:
third_term = var**0.5 * noise
x = first_term + second_term + third_term

return x

我们来把整个函数过一遍。先看一下函数的参数。相比DDPM,DDIM的采样会多出两个参数:ddim_step, eta。如正文所述,ddim_step表示执行几轮去噪迭代,eta表示DDPM和DDIM的插值系数。

1
2
3
4
5
6
7
def sample_backward(self,
img_or_shape,
net,
device,
simple_var=True,
ddim_step=20,
eta=1):

在开始迭代前,要做一些预处理。根据论文的描述,如果使用了DDPM的那种简单方差,一定要令eta=1。所以,一开始我们根据simple_vareta做一个处理。之后,我们要准备好迭代时用到的时刻。整个迭代过程中,我们会用到从self.n_steps0等间距的ddim_step+1个时刻(self.n_steps是初始时刻,不在去噪迭代中)。比如总时刻self.n_steps=100ddim_step=10ts数组里的内容就是[100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0]

1
2
3
4
5
6
7
8
9
10
if simple_var:
eta = 1
ts = torch.linspace(self.n_steps, 0,
(ddim_step + 1)).to(device).to(torch.long)
if isinstance(img_or_shape, torch.Tensor):
x = img_or_shape
else:
x = torch.randn(img_or_shape).to(device)
batch_size = x.shape[0]
net = net.to(device)

做好预处理后,进入去噪循环。在for循环中,我们从1ddim_step遍历ts的下标,从时刻数组ts里取出较大的时刻cur_t(正文中的$t_2$)和较小的时刻prev_t(正文中的$t_1$)。由于self.alpha_bars存储的是t=1, t=2, ..., t=n_steps时的变量,时刻和数组下标之间有一个1的偏移,我们要把ts里的时刻减去1得到时刻在self.alpha_bars里的下标,再取出对应的变量ab_cur, ab_prev。注意,在当前时刻为0时,self.alpha_bars是没有定义的。但由于self.alpha_bars表示连乘,我们可以特别地令当前时刻为0(prev_t=-1)时的alpha_bar=1

1
2
3
4
5
6
7
for i in tqdm(range(1, ddim_step + 1),
f'DDIM sampling with eta {eta} simple_var {simple_var}'):
cur_t = ts[i - 1] - 1
prev_t = ts[i] - 1

ab_cur = self.alpha_bars[cur_t]
ab_prev = self.alpha_bars[prev_t] if prev_t >= 0 else 1

准备好时刻后,我们使用和DDPM一样的方法,用U-Net估计生成x_t时的噪声eps,并准备好DDPM采样算法里的噪声noise(公式里的$\mathbf{z}$)。
与DDPM不同,在计算方差var时(公式里的$\sigma_t^2$),我们要给方差乘一个权重eta

1
2
3
4
5
t_tensor = torch.tensor([cur_t] * batch_size,
dtype=torch.long).to(device).unsqueeze(1)
eps = net(x, t_tensor)
var = eta * (1 - ab_prev) / (1 - ab_cur) * (1 - ab_cur / ab_prev)
noise = torch.randn_like(x)

接下来,我们要把之前算好的所有变量用起来,套入DDIM的去噪均值计算公式中。

也就是(设$\sigma_t^2 = \tilde{\beta}_t$, $\mathbf{z}$为来自标准正态分布的噪声):

由于我们只有噪声$\epsilon$,要把$\mathbf{x}_0=(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\epsilon)/\sqrt{\bar{\alpha}_t}$代入,得到不含$\mathbf{x}_0$的公式:

我在代码里把公式的三项分别命名为first_term, second_term, third_term,以便查看。

特别地,当使用DDPM的$\hat{\sigma_t}$方差取值(令$\sigma_t^2=\beta_t=\hat{\sigma_t}^2$)时,不能把这个方差套入公式中,不然$\sqrt{1-{\bar{\alpha}}_{t}-\sigma_t^2}$的根号里的数会小于0。DDIM论文提出的做法是,只修改后面和噪声$\mathbf{z}$有关的方差项,前面这个根号里的方差项保持$\sigma_t^2=(1-\bar{\alpha}_{t-1})/(1 - \bar{\alpha}_{t}) \cdot \beta_t$ ($\eta=1$)的取值。

当然,上面这些公式全都是在描述$t$到$t-1$。当描述$t_2$到$t_1$时,只需要把$\beta_t$换成$1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}$,再把所有$t$换成$t_2$,$t-1$换成$t_1$即可。

把上面的公式和处理逻辑翻译成代码,就是这样:

1
2
3
4
5
6
7
8
first_term = (ab_prev / ab_cur)**0.5 * x
second_term = ((1 - ab_prev - var)**0.5 -
(ab_prev * (1 - ab_cur) / ab_cur)**0.5) * eps
if simple_var:
third_term = (1 - ab_cur / ab_prev)**0.5 * noise
else:
third_term = var**0.5 * noise
x = first_term + second_term + third_term

这样,下一刻的x就算完了。反复执行循环即可得到最终的结果。

实验

写完了DDIM采样后,我们可以编写一个随机生成图片的函数。由于DDPMDDIM的接口非常相似,我们可以用同一套代码实现DDPM或DDIM的采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def sample_imgs(ddpm,
net,
output_path,
img_shape,
n_sample=64,
device='cuda',
simple_var=True,
to_bgr=False,
**kwargs):
if img_shape[1] >= 256:
max_batch_size = 16
elif img_shape[1] >= 128:
max_batch_size = 64
else:
max_batch_size = 256

net = net.to(device)
net = net.eval()

index = 0
with torch.no_grad():
while n_sample > 0:
if n_sample >= max_batch_size:
batch_size = max_batch_size
else:
batch_size = n_sample
n_sample -= batch_size
shape = (batch_size, *img_shape)
imgs = ddpm.sample_backward(shape,
net,
device=device,
simple_var=simple_var,
**kwargs).detach().cpu()
imgs = (imgs + 1) / 2 * 255
imgs = imgs.clamp(0, 255).to(torch.uint8)

img_list = einops.rearrange(imgs, 'n c h w -> n h w c').numpy()
output_dir = os.path.splitext(output_path)[0]
os.makedirs(output_dir, exist_ok=True)
for i, img in enumerate(img_list):
if to_bgr:
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
cv2.imwrite(f'{output_dir}/{i+index}.jpg', img)

# First iteration
if index == 0:
imgs = einops.rearrange(imgs,
'(b1 b2) c h w -> (b1 h) (b2 w) c',
b1=int(batch_size**0.5))
imgs = imgs.numpy()
if to_bgr:
imgs = cv2.cvtColor(imgs, cv2.COLOR_RGB2BGR)
cv2.imwrite(output_path, imgs)

index += batch_size

为了生成大量图片以计算FID,在这个函数中我加入了很多和batch有关的处理。剔除这些处理代码以及图像存储后处理代码,和采样有关的核心代码为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def sample_imgs(ddpm,
net,
output_path,
img_shape,
n_sample=64,
device='cuda',
simple_var=True,
to_bgr=False,
**kwargs):

net = net.to(device)
net = net.eval()

with torch.no_grad():
shape = (n_sample, *img_shape)
imgs = ddpm.sample_backward(shape,
net,
device=device,
simple_var=simple_var,
**kwargs).detach().cpu()

如果是用DDPM采样,把参数表里的那些参数填完就行了;如果是DDIM采样,则需要在kwargs里指定ddim_stepeta

使用这个函数,我们可以进行不同ddim_step和不同eta下的64x64 CelebAHQ采样实验,以尝试复现DDIM论文的实验结果。

我们先准备好变量。

1
2
3
4
5
net = UNet(n_steps, img_shape, cfg['channels'], cfg['pe_dim'],
cfg.get('with_attn', False), cfg.get('norm_type', 'ln'))
ddpm = DDPM(device, n_steps)
ddim = DDIM(device, n_steps)
net.load_state_dict(torch.load(model_path))

第一组实验是总时刻保持1000,使用$\hat{\sigma}_t$(标准DDPM)和$\eta=0$(标准DDIM)的实验。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
sample_imgs(ddpm,
net,
'work_dirs/diffusion_ddpm_sigma_hat.jpg',
img_shape,
device=device,
to_bgr=to_bgr)
sample_imgs(ddim,
net,
'work_dirs/diffusion_ddpm_eta_0.jpg',
img_shape,
device=device,
to_bgr=to_bgr,
ddim_step=1000,
simple_var=False,
eta=0)

把参数n_samples改成30000,就可以生成30000张图像,以和30000张图像的CelebAHQ之间算FID指标。由于总时刻1000的采样速度非常非常慢,建议使用dist_sample.py并行采样。

算FID指标时,可以使用torch fidelity库。使用pip即可安装此库。

1
pip install torch-fidelity

之后就可以使用命令fidelity来算指标了。假设我们把降采样过的CelebAHQ存储在data/celebA/celeba_hq_64,把我们生成的30000张图片存在work_dirs/diffusion_ddpm_sigma_hat,就可以用下面的命令算FID指标。

1
fidelity --gpu 0 --fid --input1 work_dirs/diffusion_ddpm_sigma_hat --input2 data/celebA/celeba_hq_64

整体来看,我的模型比论文差一点,总的FID会高一点。各个配置下的对比结果也稍有出入。在第一组实验中,使用$\hat{\sigma}_t$时,我的FID是13.68;使用$\eta=0$时,我的FID是13.09。而论文中用$\hat{\sigma}_t$时的FID比$\eta=0$时更低。

我们还可以做第二组实验,测试ddim_step=20(我设置的默认步数)时使用$\eta=0$, $\eta=1$, $\hat{\sigma}_t$的生成效果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
sample_imgs(ddim,
net,
'work_dirs/diffusion_ddim_sigma_hat.jpg',
img_shape,
device=device,
simple_var=True,
to_bgr=to_bgr)
sample_imgs(ddim,
net,
'work_dirs/diffusion_ddim_eta_1.jpg',
img_shape,
device=device,
simple_var=False,
eta=1,
to_bgr=to_bgr)
sample_imgs(ddim,
net,
'work_dirs/diffusion_ddim_eta_0.jpg',
img_shape,
device=device,
simple_var=False,
eta=0,
to_bgr=to_bgr)

我的FID结果是:

1
2
3
eta=0: 17.80
eta=1: 24.00
sigma hat: 213.16

这里得到的实验结果和论文一致。减少采样迭代次数后,生成质量略有降低。同采样步数下,eta=0最优。使用sigma hat的结果会有非常多的噪声,差得完全不能看。

综合上面两个实验来看,不管什么情况下,使用eta=0,得到的结果都不会太差。

从生成速度上来看,在64x64 CelebAHQ上生成256张图片,ddim_step=20时只要3秒不到,而ddim_step=1000时要200秒。基本上是步数减少到几分之一就提速几倍。可见,DDIM加速采样对于扩散模型来说是必要的。

参考文献及学习提示

如果对DDIM公式推导及其他数学知识感兴趣,欢迎阅读苏剑林的文章:
https://spaces.ac.cn/archives/9181。

DDIM的论文为Denoising diffusion implicit models(https://arxiv.org/abs/2010.02502)。

我在本文使用的公式符号都基于DDPM论文,与上面两篇文章使用的符号不一样。比如DDIM论文里的$\alpha$在本文中是用$\bar{\alpha}$表示。

DDIM论文在介绍新均值公式时很不友好地在3.1节直接不加解释地给出了公式的形式,并在附录B中以先给结论再证明这种和逻辑思维完全反过来的方法介绍了公式的由来。建议去阅读苏剑林的文章,看看是怎么按正常的思考方式正向推导出DDIM公式。

除了在3.1节直接甩给你一个公式外,DDIM论文后面的地方都很好读懂。DDIM后面还介绍了一些比较有趣的内容,比如4.3节介绍了扩散模型和常微分方程的关系,它可以帮助我们理解为什么DDPM会设置$T=1000$这么长的加噪步数。5.3节中作者介绍了如何用DDIM在两幅图像间插值。

要回顾DDPM的知识,欢迎阅读我之前的文章:DDPM详解。

在过去的大半年里,以Stable Diffusion为代表的AI绘画是世界上最为火热的AI方向之一。或许大家会有疑问,Stable Diffusion里的这个”Diffusion”是什么意思?其实,扩散模型(Diffusion Model)正是Stable Diffusion中负责生成图像的模型。想要理解Stable Diffusion的原理,就一定绕不过扩散模型的学习。

Stable Diffusion以「毕加索笔下的《最后的晚餐》」为题的绘画结果

在这篇文章里,我会由浅入深地对最基础的去噪扩散概率模型(Denoising Diffusion Probabilistic Models, DDPM)进行讲解。我会先介绍扩散模型生成图像的基本原理,再用简单的数学语言对扩散模型建模,最后给出扩散模型的一份PyTorch实现。本文不会堆砌过于复杂的数学公式,哪怕你没有相关的数学背景,也能够轻松理解扩散模型的原理。

扩散模型与图像生成

在认识扩散模型之前,我们先退一步,看看一般的神经网络模型是怎么生成图像的。显然,为了生成丰富的图像,一个图像生成程序要根据随机数来生成图像。通常,这种随机数是一个满足标准正态分布的随机向量。这样,每次要生成新图像时,只需要从标准正态分布里随机生成一个向量并输入给程序就行了。

而在AI绘画程序中,负责生成图像的是一个神经网络模型。神经网络需要从数据中学习。对于图像生成任务,神经网络的训练数据一般是一些同类型的图片。比如一个绘制人脸的神经网络会用人脸照片来训练。也就是说,神经网络会学习如何把一个向量映射成一张图片,并确保这个图片和训练集的图片是一类图片。

可是,相比其他AI任务,图像生成任务对神经网络来说更加困难一点——图像生成任务缺乏有效的指导。在其他AI任务中,训练集本身会给出一个「标准答案」,指导AI的输出向标准答案靠拢。比如对于图像分类任务,训练集会给出每一幅图像的类别;对于人脸验证任务,训练集会给出两张人脸照片是不是同一个人;对于目标检测任务,训练集会给出目标的具体位置。然而,图像生成任务是没有标准答案的。图像生成数据集里只有一些同类型图片,却没有指导AI如何画得更好的信息。

为了解决这一问题,人们专门设计了一些用于生成图像的神经网络架构。这些架构中比较出名的有生成对抗模型(GAN)和变分自编码器(VAE)。

GAN的想法是,既然不知道一幅图片好不好,就干脆再训练一个神经网络,用于辨别某图片是不是和训练集里的图片长得一样。生成图像的神经网络叫做生成器,鉴定图像的神经网络叫做判别器。两个网络互相对抗,共同进步。

VAE则使用了逆向思维:学习向量生成图像很困难,那就再同时学习怎么用图像生成向量。这样,把某图像变成向量,再用该向量生成图像,就应该得到一幅和原图像一模一样的图像。每一个向量的绘画结果有了一个标准答案,可以用一般的优化方法来指导网络的训练了。VAE中,把图像变成向量的网络叫做编码器,把向量转换回图像的网络叫做解码器。其中,解码器就是负责生成图像的模型。

一直以来,GAN的生成效果较好,但训练起来比VAE麻烦很多。有没有和GAN一样强大,训练起来又方便的生成网络架构呢?扩散模型正是满足这些要求的生成网络架构。

扩散模型是一种特殊的VAE,其灵感来自于热力学:一个分布可以通过不断地添加噪声变成另一个分布。放到图像生成任务里,就是来自训练集的图像可以通过不断添加噪声变成符合标准正态分布的图像。从这个角度出发,我们可以对VAE做以下修改:1)不再训练一个可学习的编码器,而是把编码过程固定成不断添加噪声的过程;2)不再把图像压缩成更短的向量,而是自始至终都对一个等大的图像做操作。解码器依然是一个可学习的神经网络,它的目的也同样是实现编码的逆操作。不过,既然现在编码过程变成了加噪,那么解码器就应该负责去噪。而对于神经网络来说,去噪任务学习起来会更加有效。因此,扩散模型既不会涉及GAN中复杂的对抗训练,又比VAE更强大一点。

具体来说,扩散模型由正向过程反向过程这两部分组成,对应VAE中的编码和解码。在正向过程中,输入$\mathbf{x}_0$会不断混入高斯噪声。经过$T$次加噪声操作后,图像$\mathbf{x}_T$会变成一幅符合标准正态分布的纯噪声图像。而在反向过程中,我们希望训练出一个神经网络,该网络能够学会$T$个去噪声操作,把$\mathbf{x}_T$还原回$\mathbf{x}_0$。网络的学习目标是让$T$个去噪声操作正好能抵消掉对应的加噪声操作。训练完毕后,只需要从标准正态分布里随机采样出一个噪声,再利用反向过程里的神经网络把该噪声恢复成一幅图像,就能够生成一幅图片了。

高斯噪声,就是一幅各处颜色值都满足高斯分布(正态分布)的噪声图像。

总结一下,图像生成网络会学习如何把一个向量映射成一幅图像。设计网络架构时,最重要的是设计学习目标,让网络生成的图像和给定数据集里的图像相似。VAE的做法是使用两个网络,一个学习把图像编码成向量,另一个学习把向量解码回图像,它们的目标是让复原图像和原图像尽可能相似。学习完毕后,解码器就是图像生成网络。扩散模型是一种更具体的VAE。它把编码过程固定为加噪声,并让解码器学习怎么样消除之前添加的每一步噪声。

扩散模型的具体算法

上一节中,我们只是大概了解扩散模型的整体思想。这一节,我们来引入一些数学表示,来看一看扩散模型的训练算法和采样算法具体是什么。为了便于理解,这一节会出现一些不是那么严谨的数学描述。更加详细的一些数学推导会放到下一节里介绍。

前向过程

在前向过程中,来自训练集的图像$\mathbf{x}_0$会被添加$T$次噪声,使得$x_T$为符合标准正态分布。准确来说,「加噪声」并不是给上一时刻的图像加上噪声值,而是从一个均值与上一时刻图像相关的正态分布里采样出一幅新图像。如下面的公式所示,$\mathbf{x}_{t - 1}$是上一时刻的图像,$\mathbf{x}_{t}$是这一时刻生成的图像,该图像是从一个均值与$\mathbf{x}_{t - 1}$有关的正态分布里采样出来的。

多数文章会说前向过程是一个马尔可夫过程。其实,马尔可夫过程的意思就是当前时刻的状态只由上一时刻的状态决定,而不由更早的状态决定。上面的公式表明,计算$\mathbf{x}_t$,只需要用到$\mathbf{x}_{t - 1}$,而不需要用到$\mathbf{x}_{t - 2}, \mathbf{x}_{t - 3}…$,这符合马尔可夫过程的定义。

绝大多数扩散模型会把这个正态分布设置成这个形式:

这个正态分布公式乍看起来很奇怪:$\sqrt{1 - \beta_t}$是哪里冒出来的?为什么会有这种奇怪的系数?别急,我们先来看另一个问题:假如给定$\mathbf{x}_{0}$,也就是从训练集里采样出一幅图片,该怎么计算任意一个时刻$t$的噪声图像$\mathbf{x}_{t}$呢?

我们不妨按照公式,从$\mathbf{x}_{t}$开始倒推。$\mathbf{x}_{t}$其实可以通过一个标准正态分布的样本$\epsilon_{t-1}$算出来:

再往前推几步:

由正态分布的性质可知,均值相同的正态分布「加」在一起后,方差也会加到一起。也就是$\mathcal{N}(0, \sigma_1^2 I)$与$\mathcal{N}(0, \sigma_2^2 I)$合起来会得到$\mathcal{N}(0, (\sigma_1^2+\sigma_2^2) I)$。根据这一性质,上面的公式可以化简为:

再往前推一步的话,结果是:

我们已经能够猜出规律来了,可以一直把公式推到$\mathbf{x}_{0}$。令$\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i$,则:

有了这个公式,我们就可以讨论加噪声公式为什么是$\mathbf{x}_t \sim \mathcal{N}(\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})$了。这个公式里的$\beta_t$是一个小于1的常数。在DDPM论文中,$\beta_t$从$\beta_1=10^{-4}$到$\beta_T=0.02$线性增长。这样,$\beta_t$变大,$\alpha_t$也越小,$\bar{\alpha}_t$趋于0的速度越来越快。最后,$\bar{\alpha}_T$几乎为0,代入$\mathbf{x}_T = \sqrt{\bar{\alpha}_T}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_T}\epsilon$, $\mathbf{x}_T$就满足标准正态分布了,符合我们对扩散模型的要求。上述推断可以简单描述为:加噪声公式能够从慢到快地改变原图像,让图像最终均值为0,方差为$\mathbf{I}$。

大家不妨尝试一下,设加噪声公式中均值和方差前的系数分别为$a, b$,按照上述过程计算最终分布的方差。只有$a^2 + b^2 = 1$才能保证最后$\mathbf{x}_T$的方差系数为1。

反向过程

在正向过程中,我们人为设置了$T$步加噪声过程。而在反向过程中,我们希望能够倒过来取消每一步加噪声操作,让一幅纯噪声图像变回数据集里的图像。这样,利用这个去噪声过程,我们就可以把任意一个从标准正态分布里采样出来的噪声图像变成一幅和训练数据长得差不多的图像,从而起到图像生成的目的。

现在问题来了:去噪声操作的数学形式是怎么样的?怎么让神经网络来学习它呢?数学原理表明,当$\beta_t$足够小时,每一步加噪声的逆操作也满足正态分布。

其中,当前时刻加噪声逆操作的均值$\tilde{\mu}_t$和方差$\tilde{\beta}_t$由当前的时刻$t$、当前的图像$\mathbf{x}_{t}$决定。因此,为了描述所有去噪声操作,神经网络应该输入$t$、$\mathbf{x}_{t}$,拟合当前的均值$\tilde{\mu}_t$和方差$\tilde{\beta}_t$。

不要被上文的「去噪声」、「加噪声逆操作」绕晕了哦。由于加噪声是固定的,加噪声的逆操作也是固定的。理想情况下,我们希望去噪操作就等于加噪声逆操作。然而,加噪声的逆操作不太可能从理论上求得,我们只能用一个神经网络去拟合它。去噪声操作和加噪声逆操作的关系,就是神经网络的预测值和真值的关系。

现在问题来了:加噪声逆操作的均值和方差是什么?

直接计算所有数据的加噪声逆操作的分布是不太现实的。但是,如果给定了某个训练集输入$\mathbf{x}_0$,多了一个限定条件后,该分布是可以用贝叶斯公式计算的(其中$q$表示概率分布):

等式左边的$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1};\tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})$表示加噪声操作的逆操作,它的均值和方差都是待求的。右边的$q(\mathbf{x}_{t} | \mathbf{x}_{t-1}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t};\sqrt{1 - \beta_t}\mathbf{x}_{t - 1},\beta_t\mathbf{I})$是加噪声的分布。而由于$\mathbf{x}_0$已知,$q(\mathbf{x}_{t-1} | \mathbf{x}_0)$和$q(\mathbf{x}_{t} | \mathbf{x}_0)$两项可以根据前面的公式$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon_t$得来:

这样,等式右边的式子全部已知。我们可以把公式套入,算出给定$\mathbf{x}_0$时的去噪声分布。经计算化简,分布的均值为:

其中,$\epsilon_t$是用公式算$\mathbf{x}_t$时从标准正态分布采样出的样本,它来自公式

分布的方差为:

注意,$\beta_t$是加噪声的方差,是一个常量。那么,加噪声逆操作的方差$\tilde{\beta}_t$也是一个常量,不与输入$\mathbf{x}_0$相关。这下就省事了,训练去噪网络时,神经网络只用拟合$T$个均值就行,不用再拟合方差了。

知道了均值和方差的真值,训练神经网络只差最后的问题了:该怎么设置训练的损失函数?加噪声逆操作和去噪声操作都是正态分布,网络的训练目标应该是让每对正态分布更加接近。那怎么用损失函数描述两个分布尽可能接近呢?最直观的想法,肯定是让两个正态分布的均值尽可能接近,方差尽可能接近。根据上文的分析,方差是常量,只用让均值尽可能接近就可以了。

那怎么用数学公式表达让均值更接近呢?再观察一下目标均值的公式:

神经网络拟合均值时,$\mathbf{x}_{t}$是已知的(别忘了,图像是一步一步倒着去噪的)。式子里唯一不确定的只有$\epsilon_t$。既然如此,神经网络干脆也别预测均值了,直接预测一个噪声$\epsilon_\theta(\mathbf{x}_{t}, t)$(其中$\theta$为可学习参数),让它和生成$\mathbf{x}_{t}$的噪声$\epsilon_t$的均方误差最小就行了。对于一轮训练,最终的误差函数可以写成

这样,我们就认识了反向过程的所有内容。总结一下,反向过程中,神经网络应该让$T$个去噪声操作拟合对应的$T$个加噪声逆操作。每步加噪声逆操作符合正态分布,且在给定某个输入时,该正态分布的均值和方差是可以用解析式表达出来的。因此,神经网络的学习目标就是让其输出的去噪声分布和理论计算的加噪声逆操作分布一致。经过数学计算上的一些化简,问题被转换成了拟合生成$\mathbf{x}_{t}$时用到的随机噪声$\epsilon_t$。

训练算法与采样算法

理解了前向过程和反向过程后,训练神经网络的算法和采样图片(生成图片)的算法就呼之欲出了。

以下是DDPM论文中的训练算法:

让我们来逐行理解一下这个算法。第二行是指从训练集里取一个数据$\mathbf{x}_{0}$。第三行是指随机从$1, …, T$里取一个时刻用来训练。我们虽然要求神经网络拟合$T$个正态分布,但实际训练时,不用一轮预测$T$个结果,只需要随机预测$T$个时刻中某一个时刻的结果就行。第四行指随机生成一个噪声$\epsilon$,该噪声是用于执行前向过程生成$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon$的。之后,我们把$\mathbf{x}_t$和$t$传给神经网络$\epsilon_\theta(\mathbf{x}_{t}, t)$,让神经网络预测随机噪声。训练的损失函数是预测噪声和实际噪声之间的均方误差,对此损失函数采用梯度下降即可优化网络。

DDPM并没有规定神经网络的结构。根据任务的难易程度,我们可以自己定义简单或复杂的网络结构。这里只需要把$\epsilon_\theta(\mathbf{x}_{t}, t)$当成一个普通的映射即可。

训练好了网络后,我们可以执行反向过程,对任意一幅噪声图像去噪,以实现图像生成。这个算法如下:

第一行的$\mathbf{x}_{T}$就是从标准正态分布里随机采样的输入噪声。要生成不同的图像,只需要更换这个噪声。后面的过程就是扩散模型的反向过程。令时刻从$T$到$1$,计算这一时刻去噪声操作的均值和方差,并采样出$\mathbf{x}_{t-1}$。均值是用之前提到的公式计算的:

而方差$\sigma_t^2$的公式有两种选择,两个公式都能产生差不多的结果。实验表明,当$\mathbf{x}_{0}$是特定的某个数据时,用上一节推导出来的方差最好。

而当$\mathbf{x}_{0} \sim \mathcal{N}(0, \mathbf{I})$时,只需要令方差和加噪声时的方差一样即可。

循环执行去噪声操作。最后生成的$\mathbf{x}_{0}$就是生成出来的图像。

特别地,最后一步去噪声是不用加方差项的。为什么呢,观察公式$\sigma_t^2=\frac{1-\bar{\alpha}_{t-1}}{1 - \bar{\alpha}_{t}} \cdot \beta_t$。当$t=1$时,分子会出现$\bar{\alpha}_{t-1}=\bar{\alpha}_0$这一项。$\bar{\alpha}_t$是一个连乘,理论上$t$是从$1$开始的,在$t=0$时没有定义。但我们可以特别地令连乘的第0项$\bar{\alpha}_0=1$。这样,$t=1$时方差项的分子$1-\bar{\alpha}_{t-1}$为$0$,不用算这一项了。

当然,这一解释从数学上来说是不严谨的。据论文说,这部分的解释可以参见朗之万动力学。

数学推导的补充 (选读)

理解了训练算法和采样算法,我们就算是搞懂了扩散模型,可以去编写代码了。不过,上文的描述省略了一些数学推导的细节。如果对扩散模型更深的原理感兴趣,可以阅读一下本节。

加噪声逆操作均值和方差的推导

上一节,我们根据下面几个式子

一步就给出了$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})$的均值和方差。

现在我们来看一下推导均值和方差的思路。

首先,把其他几个式子带入贝叶斯公式的等式右边。

由于多个正态分布的乘积还是一个正态分布,我们知道$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$也可以用一个正态分布公式$\mathcal{N}(\mathbf{x}_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t\mathbf{I})$表达,它最后一定能写成这种形式:

问题就变成了怎么把开始那个很长的式子化简,算出$\tilde{\mu}_t$和$\tilde{\beta}_t$。

方差$\tilde{\beta}_t$可以从指数函数的系数得来,比较好求。系数为

所以,方差为:

接下来只要关注指数函数的指数部分。指数部分一定是一个关于的$\mathbf{x}_{t-1}$的二次函数,只要化简成$(\mathbf{x}_{t-1}-C)^2$的形式,再除以一下$-2$倍方差,就可以得到均值了。

指数部分为:

$\mathbf{x}_{t-1}$只在前两项里有。把和$\mathbf{x}_{t-1}$有关的项计算化简,可以计算出均值:

回想一下,在去噪声中,神经网络的输入是$\mathbf{x}_{t}$和$t$。也就是说,上式中$\mathbf{x}_{t}$已知,只有$\mathbf{x}_{0}$一个未知量。要算均值,还需要算出$\mathbf{x}_{0}$。$\mathbf{x}_{0}$和$\mathbf{x}_{t}$之间是有一定联系的。$\mathbf{x}_{t}$是$\mathbf{x}_{0}$在正向过程中第$t$步加噪声的结果。而根据正向过程的公式倒推:

把这个$\mathbf{x}_{0}$带入均值公式,均值最后会化简成我们熟悉的形式。

优化目标

上一节,我们只是简单地说神经网络的优化目标是让加噪声和去噪声的均值接近。而让均值接近,就是让生成$\mathbf{x}_t$的噪声$\epsilon_t$更接近。实际上,这个优化目标是经过简化得来的。扩散模型最早的优化目标是有一定的数学意义的。

扩散模型,全称为扩散概率模型(Diffusion Probabilistic Model)。最简单的一类扩散模型,是去噪扩散概率模型(Denoising Diffusion Probabilistic Model),也就是常说的DDPM。DDPM的框架主要是由两篇论文建立起来的。第一篇论文是首次提出扩散模型思想的Deep Unsupervised Learning using Nonequilibrium Thermodynamics。在此基础上,Denoising Diffusion Probabilistic Models对最早的扩散模型做出了一定的简化,让图像生成效果大幅提升,促成了扩散模型的广泛使用。我们上一节看到的公式,全部是简化后的结果。

扩散概率模型的名字之所以有「概率」二字,是因为这个模型是在描述一个系统的概率。准确来说,扩散模型是在描述经反向过程生成出某一项数据的概率。也就是说,扩散模型$p_{\theta}(\mathbf{x}_0)$是一个有着可训练参数$\theta$的模型,它描述了反向过程生成出数据$\mathbf{x}_0$的概率。$p_{\theta}(\mathbf{x}_0)$满足$p_{\theta}(\mathbf{x}_0)=\int p_{\theta}(\mathbf{x}_{0:T})d\mathbf{x}_{1:T}$,其中$p_{\theta}(\mathbf{x}_{0:T})$就是我们熟悉的反向过程,只不过它是以概率计算的形式表达:

我们上一节里见到的优化目标,是让去噪声操作$p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t})$和加噪声操作的逆操作$q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0)$尽可能相似。然而,这个描述并不确切。扩散模型原本的目标,是最大化$p_{\theta}(\mathbf{x}_0)$这个概率,其中$\mathbf{x}_0$是来自训练集的数据。换个角度说,给定一个训练集的数据$\mathbf{x}_0$,经过前向过程和反向过程,扩散模型要让复原出$\mathbf{x}_0$的概率尽可能大。这也是我们在本文开头认识VAE时见到的优化目标。

最大化$p_{\theta}(\mathbf{x}_0)$,一般会写成最小化其负对数值,即最小化$-log p_{\theta}(\mathbf{x}_0)$。使用和VAE类似的变分推理,可以把优化目标转换成优化一个叫做变分下界(variational lower bound, VLB)的量。它最终可以写成:

这里的$D_{KL}(P||Q)$表示分布P和Q之间的KL散度。KL散度是衡量两个分布相似度的指标。如果$P, Q$都是正态分布,则它们的KL散度可以由一个简单的公式给出。关于KL散度的知识可以参见我之前的文章:从零理解熵、交叉熵、KL散度。

其中,第一项$D_{KL}(q(\mathbf{x}_T|\mathbf{x}_0) || p_\theta(\mathbf{x}_T))$和可学习参数$\theta$无关(因为可学习参数只描述了每一步去噪声操作,也就是只描述了$p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t})$),可以不去管它。那么这个优化目标就由两部分组成:

  1. 最小化$D_{KL}(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) || p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t}))$表示的是最大化每一个去噪声操作和加噪声逆操作的相似度。
  2. 最小化$- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})$就是已知$\mathbf{x}_{1}$时,让最后复原原图$\mathbf{x}_{0}$概率更高。

我们分别看这两部分是怎么计算的。

对于第一部分,我们先回顾一下正态分布之间的KL散度公式。设一维正态分布$P, Q$的公式如下:

而对于$D_{KL}(q(\mathbf{x}_{t-1} | \mathbf{x}_{t}, \mathbf{x}_0) || p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_{t}))$,根据前文的分析,我们知道,待求方差$\Sigma_\theta(\mathbf{x}_{t}, t)$可以直接由计算得到。

两个正态分布方差的比值是常量。所以,在计算KL散度时,不用管方差那一项了,只需要管均值那一项。

由根据之前的均值公式

这一部分的优化目标可以化简成

DDPM论文指出,如果把前面的系数全部丢掉的话,模型的效果更好。最终,我们就能得到一个非常简单的优化目标:

这就是我们上一节见到的优化目标。

当然,还没完,别忘了优化目标里还有$- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})$这一项。它的形式为:

只管后面有$\theta$的那一项(注意,$\alpha_1=\bar{\alpha}_1=1-\beta_1$):

这和那些KL散度项$t=1$时的形式相同,我们可以用相同的方式简化优化目标,只保留$|| \epsilon_1-\epsilon_\theta(\mathbf{x}_{1}, 1)||^2$。这样,损失函数的形式全都是$||\epsilon_t-\epsilon_{\theta}(\mathbf{x}_{t}, t)||^2$了。

DDPM论文里写$- logp_\theta(\mathbf{x}_{0}|\mathbf{x}_{1})$这一项可以直接满足简化后的公式$t=1$时的情况,而没有去掉系数的过程。我在网上没找到文章解释这一点,只好按自己的理解来推导这个误差项了。不论如何,推导的过程不是那么重要,重要的是最后的简化形式。

总结

图像生成任务就是把随机生成的向量(噪声)映射成和训练图像类似的图像。为此,扩散模型把这个过程看成是对纯噪声图像的去噪过程。通过学习把图像逐步变成纯噪声的逆操作,扩散模型可以把任何一个纯噪声图像变成有意义的图像,也就是完成图像生成。

对于不同程度的读者,应该对本文有不同的认识。

对于只想了解扩散模型大概原理的读者,只需要阅读第一节,并大概了解:

  • 图像生成任务的通常做法
  • 图像生成任务需要监督
  • VAE通过把图像编码再解码来训练一个解码器
  • 扩散模型是一类特殊的VAE,它的编码固定为加噪声,解码固定为去噪声

对于想认真学习扩散模型的读者,只需读懂第二节的主要内容:

  • 扩散模型的优化目标:让反向过程尽可能成为正向过程的逆操作
  • 正向过程的公式
  • 反向过程的做法(采样算法)
  • 加噪声逆操作的均值和方差在给定$\mathbf{x}_{0}$时可以求出来的,加噪声逆操作的均值就是去噪声的学习目标
  • 简化后的损失函数与训练算法

对有学有余力对数学感兴趣的读者,可以看一看第三节的内容:

  • 加噪声逆操作均值和方差的推导
  • 扩散模型最早的优化目标与DDPM论文是如何简化优化目标的

我个人认为,由于扩散模型的优化目标已经被大幅度简化,除非你的研究目标是改进扩散模型本身,否则没必要花过多的时间钻研数学原理。在学习时,建议快点看懂扩散模型的整体思想,搞懂最核心的训练算法和采样算法,跑通代码。之后就可以去看较新的论文了。

在附录中,我给出了一份DDPM的简单实现。欢迎大家参考,并自己动手复现一遍DDPM。

参考资料与学习建议

网上绝大多数的中英文教程都是照搬 https://lilianweng.github.io/posts/2021-07-11-diffusion-models/ 这篇文章的。这篇文章像教科书一样严谨,适合有一定数学基础的人阅读,但不适合给初学者学习。建议在弄懂扩散模型的大概原理后再来阅读这篇文章补充细节。

多数介绍扩散模型的文章对没学过相关数学知识的人来说很不友好,我在阅读此类文章时碰到了大量的问题:为什么前向公式里有个$\sqrt{1-\beta}$?为什么突然冒出一个快速算$\mathbf{x}_{t}$的公式?为什么反向过程里来了个贝叶斯公式?优化目标是什么?$-log p_{\theta}(\mathbf{x}_0)$是什么?为什么优化目标里一大堆项,每一项的意义又是什么?为什么最后莫名其妙算一个$\epsilon$?为什么采样时$t=0$就不用加方差项了?好不容易,我才把这些问题慢慢搞懂,并在本文做出了解释。希望我的解答能够帮助到同样有这些困惑的读者。想逐步学习扩散模型,可以先看懂我这篇文章的大概讲解,再去其他文章里学懂一些细节。无论是教,还是学,最重要的都是搞懂整体思路,知道动机,最后再去强调细节。

再强烈推荐一位作者写的DDPM系列介绍:https://kexue.fm/archives/9119 。这位作者是全网为数不多的能令我敬佩的作者。早知道有这些文章,我也没必要自己写一遍了。

这里还有篇文章给出了扩散模型中数学公式的详细推导,并补充了变分推理的背景介绍,适合从头学起:https://arxiv.org/abs/2208.11970

想深入学习DDPM,可以看一看最重要的两篇论文:Deep Unsupervised Learning using Nonequilibrium ThermodynamicsDenoising Diffusion Probabilistic Models。当然,后者更重要一些,里面的一些实验结果仍有阅读价值。

我在代码复现时参考了这篇文章。相对于网上的其他开源DDPM实现,这份代码比较简短易懂,更适合学习。不过,这份代码有一点问题。它的神经网络不够强大,采样结果会有一点问题。

附录:代码复现

在这个项目中,我们要用PyTorch实现一个基于U-Net的DDPM,并在MNIST数据集(经典的手写数字数据集)上训练它。模型几分钟就能训练完,我们可以方便地做各种各样的实验。

后续讲解只会给出代码片段,完整的代码请参见 https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/ddpm 。git clone 仓库并安装后,可以直接运行目录里的main.py训练模型并采样。

获取数据集

PyTorch的torchvision提供了获取了MNIST的接口,我们只需要用下面的函数就可以生成MNIST的Dataset实例。参数中,root为数据集的下载路径,download为是否自动下载数据集。令download=True的话,第一次调用该函数时会自动下载数据集,而第二次之后就不用下载了,函数会读取存储在root里的数据。

1
mnist = torchvision.datasets.MNIST(root='data/mnist', download=True)

我们可以用下面的代码来下载MNIST并输出该数据集的一些信息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torchvision
from torchvision.transforms import ToTensor
def download_dataset():
mnist = torchvision.datasets.MNIST(root='data/mnist', download=True)
print('length of MNIST', len(mnist))
id = 4
img, label = mnist[id]
print(img)
print(label)

# On computer with monitor
# img.show()

img.save('work_dirs/tmp.jpg')
tensor = ToTensor()(img)
print(tensor.shape)
print(tensor.max())
print(tensor.min())

if __name__ == '__main__':
download_dataset()

执行这段代码,输出大致为:

1
2
3
4
5
6
length of MNIST 60000
<PIL.Image.Image image mode=L size=28x28 at 0x7FB3F09CCE50>
9
torch.Size([1, 28, 28])
tensor(1.)
tensor(0.)

第一行输出表明,MNIST数据集里有60000张图片。而从第二行和第三行输出中,我们发现每一项数据由图片和标签组成,图片是大小为28x28的PIL格式的图片,标签表明该图片是哪个数字。我们可以用torchvision里的ToTensor()把PIL图片转成PyTorch张量,进一步查看图片的信息。最后三行输出表明,每一张图片都是单通道图片(灰度图),颜色值的取值范围是0~1。

我们可以查看一下每张图片的样子。如果你是在用带显示器的电脑,可以去掉img.show那一行的注释,直接查看图片;如果你是在用服务器,可以去img.save的路径里查看图片。该图片的应该长这个样子:

我们可以用下面的代码预处理数据并创建DataLoader。由于DDPM会把图像和正态分布关联起来,我们更希望图像颜色值的取值范围是[-1, 1]。为此,我们可以对图像做一个线性变换,减0.5再乘2。

1
2
3
4
5
6
def get_dataloader(batch_size: int):
transform = Compose([ToTensor(), Lambda(lambda x: (x - 0.5) * 2)])
dataset = torchvision.datasets.MNIST(root='data/mnist',
transform=transform)
return DataLoader(dataset, batch_size=batch_size, shuffle=True)

DDPM 类

在代码中,我们要实现一个DDPM类。它维护了扩散过程中的一些常量(比如$\alpha$),并且可以计算正向过程和反向过程的结果。

先来实现一下DDPM类的初始化函数。一开始,我们遵从论文的配置,用torch.linspace(min_beta, max_beta, n_steps)min_betamax_beta线性地生成n_steps个时刻的$\beta$。接着,我们根据公式$\alpha_t=1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^t\alpha_i$,计算每个时刻的alphaalpha_bar。注意,为了方便实现,我们让t的取值从0开始,要比论文里的$t$少1。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch

class DDPM():

# n_steps 就是论文里的 T
def __init__(self,
device,
n_steps: int,
min_beta: float = 0.0001,
max_beta: float = 0.02):
betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
alphas = 1 - betas
alpha_bars = torch.empty_like(alphas)
product = 1
for i, alpha in enumerate(alphas):
product *= alpha
alpha_bars[i] = product
self.betas = betas
self.n_steps = n_steps
self.alphas = alphas
self.alpha_bars = alpha_bars

部分实现会让 DDPM 继承torch.nn.Module,但我认为这样不好。DDPM本身不是一个神经网络,它只是描述了前向过程和后向过程的一些计算。只有涉及可学习参数的神经网络类才应该继承 torch.nn.Module

准备好了变量后,我们可以来实现DDPM类的其他方法。先实现正向过程方法,该方法会根据公式$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_t}\epsilon_t$计算正向过程中的$\mathbf{x}_t$。

1
2
3
4
5
6
def sample_forward(self, x, t, eps=None):
alpha_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)
if eps is None:
eps = torch.randn_like(x)
res = eps * torch.sqrt(1 - alpha_bar) + torch.sqrt(alpha_bar) * x
return res

这里要解释一些PyTorch编程上的细节。这份代码中,self.alpha_bars是一个一维Tensor。而在并行训练中,我们一般会令t为一个形状为(batch_size, )Tensor。PyTorch允许我们直接用self.alpha_bars[t]self.alpha_bars里取出batch_size个数,就像用一个普通的整型索引来从数组中取出一个数一样。有些实现会用torch.gatherself.alpha_bars里取数,其作用是一样的。

我们可以随机从训练集取图片做测试,看看它们在前向过程中是怎么逐步变成噪声的。

接下来实现反向过程。在反向过程中,DDPM会用神经网络预测每一轮去噪的均值,把$\mathbf{x}_t$复原回$\mathbf{x}_0$,以完成图像生成。反向过程即对应论文中的采样算法。

其实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def sample_backward(self, img_shape, net, device, simple_var=True):
x = torch.randn(img_shape).to(device)
net = net.to(device)
for t in range(self.n_steps - 1, -1, -1):
x = self.sample_backward_step(x, t, net, simple_var)
return x

def sample_backward_step(self, x_t, t, net, simple_var=True):
n = x_t.shape[0]
t_tensor = torch.tensor([t] * n,
dtype=torch.long).to(x_t.device).unsqueeze(1)
eps = net(x_t, t_tensor)

if t == 0:
noise = 0
else:
if simple_var:
var = self.betas[t]
else:
var = (1 - self.alpha_bars[t - 1]) / (
1 - self.alpha_bars[t]) * self.betas[t]
noise = torch.randn_like(x_t)
noise *= torch.sqrt(var)

mean = (x_t -
(1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *
eps) / torch.sqrt(self.alphas[t])
x_t = mean + noise

return x_t

其中,sample_backward是用来给外部调用的方法,而sample_backward_step是执行一步反向过程的方法。

sample_backward会随机生成纯噪声x(对应$\mathbf{x}_T$),再令tn_steps - 10,调用sample_backward_step

1
2
3
4
5
6
def sample_backward(self, img_shape, net, device, simple_var=True):
x = torch.randn(img_shape).to(device)
net = net.to(device)
for t in range(self.n_steps - 1, -1, -1):
x = self.sample_backward_step(x, t, net, simple_var)
return x

sample_backward_step中,我们先准备好这一步的神经网络输出eps。为此,我们要把整型的t转换成一个格式正确的Tensor。考虑到输入里可能有多个batch,我们先获取batch size n,再根据它来生成t_tensor

1
2
3
4
5
6
def sample_backward_step(self, x_t, t, net, simple_var=True):

n = x_t.shape[0]
t_tensor = torch.tensor([t] * n,
dtype=torch.long).to(x_t.device).unsqueeze(1)
eps = net(x_t, t_tensor)

之后,我们来处理反向过程公式中的方差项。根据伪代码,我们仅在t非零的时候算方差项。方差项用到的方差有两种取值,效果差不多,我们用simple_var来控制选哪种取值方式。获取方差后,我们再随机采样一个噪声,根据公式,得到方差项。

1
2
3
4
5
6
7
8
9
10
if t == 0:
noise = 0
else:
if simple_var:
var = self.betas[t]
else:
var = (1 - self.alpha_bars[t - 1]) / (
1 - self.alpha_bars[t]) * self.betas[t]
noise = torch.randn_like(x_t)
noise *= torch.sqrt(var)

最后,我们把eps和方差项套入公式,得到这一步更新过后的图像x_t

1
2
3
4
5
6
mean = (x_t -
(1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) *
eps) / torch.sqrt(self.alphas[t])
x_t = mean + noise

return x_t

稍后完成了训练后,我们再来看反向过程的输出结果。

训练算法

接下来,我们先跳过神经网络的实现,直接完成论文里的训练算法。

再回顾一遍伪代码。首先,我们要随机选取训练图片$\mathbf{x}_{0}$,随机生成当前要训练的时刻$t$,以及随机生成一个生成$\mathbf{x}_{t}$的高斯噪声。之后,我们把$\mathbf{x}_{t}$和$t$输入进神经网络,尝试预测噪声。最后,我们以预测噪声和实际噪声的均方误差为损失函数做梯度下降。

为此,我们可以用下面的代码实现训练。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
from dldemos.ddpm.dataset import get_dataloader, get_img_shape
from dldemos.ddpm.ddpm import DDPM
import cv2
import numpy as np
import einops

batch_size = 512
n_epochs = 100


def train(ddpm: DDPM, net, device, ckpt_path):
# n_steps 就是公式里的 T
# net 是某个继承自 torch.nn.Module 的神经网络
n_steps = ddpm.n_steps
dataloader = get_dataloader(batch_size)
net = net.to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), 1e-3)

for e in range(n_epochs):
for x, _ in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)
t = torch.randint(0, n_steps, (current_batch_size, )).to(device)
eps = torch.randn_like(x).to(device)
x_t = ddpm.sample_forward(x, t, eps)
eps_theta = net(x_t, t.reshape(current_batch_size, 1))
loss = loss_fn(eps_theta, eps)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(net.state_dict(), ckpt_path)

代码的主要逻辑都在循环里。首先是完成训练数据$\mathbf{x}_{0}$、$t$、噪声的采样。采样$\mathbf{x}_{0}$的工作可以交给PyTorch的DataLoader完成,每轮遍历得到的x就是训练数据。$t$的采样可以用torch.randint函数随机从[0, n_steps - 1]取数。采样高斯噪声可以直接用torch.randn_like(x)生成一个和训练图片x形状一样的符合标准正态分布的图像。

1
2
3
4
5
for x, _ in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)
t = torch.randint(0, n_steps, (current_batch_size, )).to(device)
eps = torch.randn_like(x).to(device)

之后计算$\mathbf{x}_{t}$并将其和$t$输入进神经网络net。计算$\mathbf{x}_{t}$的任务会由DDPM类的sample_forward方法完成,我们在上文已经实现了它。

1
2
x_t = ddpm.sample_forward(x, t, eps)
eps_theta = net(x_t, t.reshape(current_batch_size, 1))

得到了预测的噪声eps_theta,我们调用PyTorch的API,算均方误差并调用优化器即可。

1
2
3
4
5
6
7
8
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), 1e-3)

...
loss = loss_fn(eps_theta, eps)
optimizer.zero_grad()
loss.backward()
optimizer.step()

去噪神经网络

在DDPM中,理论上我们可以用任意一种神经网络架构。但由于DDPM任务十分接近图像去噪任务,而U-Net又是去噪任务中最常见的网络架构,因此绝大多数DDPM都会使用基于U-Net的神经网络。

我一直想训练一个尽可能简单的模型。经过多次实验,我发现DDPM的神经网络很难训练。哪怕是对于比较简单的MNIST数据集,结构差一点的网络(比如纯ResNet)都不太行,只有带了残差块和时序编码的U-Net才能较好地完成去噪。注意力模块倒是可以不用加上。

由于神经网络结构并不是DDPM学习的重点,我这里就不对U-Net的写法做解说,而是直接贴上代码了。代码中大部分内容都和普通的U-Net无异。唯一要注意的地方就是时序编码。去噪网络的输入除了图像外,还有一个时间戳t。我们要考虑怎么把t的信息和输入图像信息融合起来。大部分人的做法是对t进行Transformer中的位置编码,把该编码加到图像的每一处上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import torch
import torch.nn as nn
import torch.nn.functional as F
from dldemos.ddpm.dataset import get_img_shape


class PositionalEncoding(nn.Module):

def __init__(self, max_seq_len: int, d_model: int):
super().__init__()

# Assume d_model is an even number for convenience
assert d_model % 2 == 0

pe = torch.zeros(max_seq_len, d_model)
i_seq = torch.linspace(0, max_seq_len - 1, max_seq_len)
j_seq = torch.linspace(0, d_model - 2, d_model // 2)
pos, two_i = torch.meshgrid(i_seq, j_seq)
pe_2i = torch.sin(pos / 10000**(two_i / d_model))
pe_2i_1 = torch.cos(pos / 10000**(two_i / d_model))
pe = torch.stack((pe_2i, pe_2i_1), 2).reshape(max_seq_len, d_model)

self.embedding = nn.Embedding(max_seq_len, d_model)
self.embedding.weight.data = pe
self.embedding.requires_grad_(False)

def forward(self, t):
return self.embedding(t)


class ResidualBlock(nn.Module):

def __init__(self, in_c: int, out_c: int):
super().__init__()
self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(out_c)
self.actvation1 = nn.ReLU()
self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)
self.bn2 = nn.BatchNorm2d(out_c)
self.actvation2 = nn.ReLU()
if in_c != out_c:
self.shortcut = nn.Sequential(nn.Conv2d(in_c, out_c, 1),
nn.BatchNorm2d(out_c))
else:
self.shortcut = nn.Identity()

def forward(self, input):
x = self.conv1(input)
x = self.bn1(x)
x = self.actvation1(x)
x = self.conv2(x)
x = self.bn2(x)
x += self.shortcut(input)
x = self.actvation2(x)
return x


class ConvNet(nn.Module):

def __init__(self,
n_steps,
intermediate_channels=[10, 20, 40],
pe_dim=10,
insert_t_to_all_layers=False):
super().__init__()
C, H, W = get_img_shape() # 1, 28, 28
self.pe = PositionalEncoding(n_steps, pe_dim)

self.pe_linears = nn.ModuleList()
self.all_t = insert_t_to_all_layers
if not insert_t_to_all_layers:
self.pe_linears.append(nn.Linear(pe_dim, C))

self.residual_blocks = nn.ModuleList()
prev_channel = C
for channel in intermediate_channels:
self.residual_blocks.append(ResidualBlock(prev_channel, channel))
if insert_t_to_all_layers:
self.pe_linears.append(nn.Linear(pe_dim, prev_channel))
else:
self.pe_linears.append(None)
prev_channel = channel
self.output_layer = nn.Conv2d(prev_channel, C, 3, 1, 1)

def forward(self, x, t):
n = t.shape[0]
t = self.pe(t)
for m_x, m_t in zip(self.residual_blocks, self.pe_linears):
if m_t is not None:
pe = m_t(t).reshape(n, -1, 1, 1)
x = x + pe
x = m_x(x)
x = self.output_layer(x)
return x


class UnetBlock(nn.Module):

def __init__(self, shape, in_c, out_c, residual=False):
super().__init__()
self.ln = nn.LayerNorm(shape)
self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)
self.activation = nn.ReLU()
self.residual = residual
if residual:
if in_c == out_c:
self.residual_conv = nn.Identity()
else:
self.residual_conv = nn.Conv2d(in_c, out_c, 1)

def forward(self, x):
out = self.ln(x)
out = self.conv1(out)
out = self.activation(out)
out = self.conv2(out)
if self.residual:
out += self.residual_conv(x)
out = self.activation(out)
return out


class UNet(nn.Module):

def __init__(self,
n_steps,
channels=[10, 20, 40, 80],
pe_dim=10,
residual=False) -> None:
super().__init__()
C, H, W = get_img_shape()
layers = len(channels)
Hs = [H]
Ws = [W]
cH = H
cW = W
for _ in range(layers - 1):
cH //= 2
cW //= 2
Hs.append(cH)
Ws.append(cW)

self.pe = PositionalEncoding(n_steps, pe_dim)

self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
self.pe_linears_en = nn.ModuleList()
self.pe_linears_de = nn.ModuleList()
self.downs = nn.ModuleList()
self.ups = nn.ModuleList()
prev_channel = C
for channel, cH, cW in zip(channels[0:-1], Hs[0:-1], Ws[0:-1]):
self.pe_linears_en.append(
nn.Sequential(nn.Linear(pe_dim, prev_channel), nn.ReLU(),
nn.Linear(prev_channel, prev_channel)))
self.encoders.append(
nn.Sequential(
UnetBlock((prev_channel, cH, cW),
prev_channel,
channel,
residual=residual),
UnetBlock((channel, cH, cW),
channel,
channel,
residual=residual)))
self.downs.append(nn.Conv2d(channel, channel, 2, 2))
prev_channel = channel

self.pe_mid = nn.Linear(pe_dim, prev_channel)
channel = channels[-1]
self.mid = nn.Sequential(
UnetBlock((prev_channel, Hs[-1], Ws[-1]),
prev_channel,
channel,
residual=residual),
UnetBlock((channel, Hs[-1], Ws[-1]),
channel,
channel,
residual=residual),
)
prev_channel = channel
for channel, cH, cW in zip(channels[-2::-1], Hs[-2::-1], Ws[-2::-1]):
self.pe_linears_de.append(nn.Linear(pe_dim, prev_channel))
self.ups.append(nn.ConvTranspose2d(prev_channel, channel, 2, 2))
self.decoders.append(
nn.Sequential(
UnetBlock((channel * 2, cH, cW),
channel * 2,
channel,
residual=residual),
UnetBlock((channel, cH, cW),
channel,
channel,
residual=residual)))

prev_channel = channel

self.conv_out = nn.Conv2d(prev_channel, C, 3, 1, 1)

def forward(self, x, t):
n = t.shape[0]
t = self.pe(t)
encoder_outs = []
for pe_linear, encoder, down in zip(self.pe_linears_en, self.encoders,
self.downs):
pe = pe_linear(t).reshape(n, -1, 1, 1)
x = encoder(x + pe)
encoder_outs.append(x)
x = down(x)
pe = self.pe_mid(t).reshape(n, -1, 1, 1)
x = self.mid(x + pe)
for pe_linear, decoder, up, encoder_out in zip(self.pe_linears_de,
self.decoders, self.ups,
encoder_outs[::-1]):
pe = pe_linear(t).reshape(n, -1, 1, 1)
x = up(x)

pad_x = encoder_out.shape[2] - x.shape[2]
pad_y = encoder_out.shape[3] - x.shape[3]
x = F.pad(x, (pad_x // 2, pad_x - pad_x // 2, pad_y // 2,
pad_y - pad_y // 2))
x = torch.cat((encoder_out, x), dim=1)
x = decoder(x + pe)
x = self.conv_out(x)
return x


convnet_small_cfg = {
'type': 'ConvNet',
'intermediate_channels': [10, 20],
'pe_dim': 128
}

convnet_medium_cfg = {
'type': 'ConvNet',
'intermediate_channels': [10, 10, 20, 20, 40, 40, 80, 80],
'pe_dim': 256,
'insert_t_to_all_layers': True
}
convnet_big_cfg = {
'type': 'ConvNet',
'intermediate_channels': [20, 20, 40, 40, 80, 80, 160, 160],
'pe_dim': 256,
'insert_t_to_all_layers': True
}

unet_1_cfg = {'type': 'UNet', 'channels': [10, 20, 40, 80], 'pe_dim': 128}
unet_res_cfg = {
'type': 'UNet',
'channels': [10, 20, 40, 80],
'pe_dim': 128,
'residual': True
}


def build_network(config: dict, n_steps):
network_type = config.pop('type')
if network_type == 'ConvNet':
network_cls = ConvNet
elif network_type == 'UNet':
network_cls = UNet

network = network_cls(n_steps, **config)
return network

实验结果与采样

把之前的所有代码综合一下,我们以带残差块的U-Net为去噪网络,执行训练。

1
2
3
4
5
6
7
8
9
10
11
if __name__ == '__main__':
n_steps = 1000
config_id = 4
device = 'cuda'
model_path = 'dldemos/ddpm/model_unet_res.pth'

config = unet_res_cfg
net = build_network(config, n_steps)
ddpm = DDPM(device, n_steps)

train(ddpm, net, device=device, ckpt_path=model_path)

按照默认训练配置,在3090上花5分钟不到,训练30~40个epoch即可让网络基本收敛。最终收敛时loss在0.023~0.024左右。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
batch size: 512
epoch 0 loss: 0.23103461712201437 elapsed 7.01s
epoch 1 loss: 0.0627968365987142 elapsed 13.66s
epoch 2 loss: 0.04828845852613449 elapsed 20.25s
epoch 3 loss: 0.04148937337398529 elapsed 26.80s
epoch 4 loss: 0.03801360730528831 elapsed 33.37s
epoch 5 loss: 0.03604260584712028 elapsed 39.96s
epoch 6 loss: 0.03357676289876302 elapsed 46.57s
epoch 7 loss: 0.0335664684087038 elapsed 53.15s
...
epoch 30 loss: 0.026149748386939366 elapsed 204.64s
epoch 31 loss: 0.025854381563266117 elapsed 211.24s
epoch 32 loss: 0.02589433005253474 elapsed 217.84s
epoch 33 loss: 0.026276464049021404 elapsed 224.41s
...
epoch 96 loss: 0.023299352884292603 elapsed 640.25s
epoch 97 loss: 0.023460942271351815 elapsed 646.90s
epoch 98 loss: 0.023584651704629263 elapsed 653.54s
epoch 99 loss: 0.02364126600921154 elapsed 660.22s

训练这个网络时,并没有特别好的测试指标,我们只能通过观察采样图像来评价网络的表现。我们可以用下面的代码调用DDPM的反向传播方法,生成多幅图像并保存下来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def sample_imgs(ddpm,
net,
output_path,
n_sample=81,
device='cuda',
simple_var=True):
net = net.to(device)
net = net.eval()
with torch.no_grad():
shape = (n_sample, *get_img_shape()) # 1, 3, 28, 28
imgs = ddpm.sample_backward(shape,
net,
device=device,
simple_var=simple_var).detach().cpu()
imgs = (imgs + 1) / 2 * 255
imgs = imgs.clamp(0, 255)
imgs = einops.rearrange(imgs,
'(b1 b2) c h w -> (b1 h) (b2 w) c',
b1=int(n_sample**0.5))

imgs = imgs.numpy().astype(np.uint8)

cv2.imwrite(output_path, imgs)

一切顺利的话,我们可以得到一些不错的生成结果。下图是我得到的一些生成图片:

大部分生成的图片都对应一个阿拉伯数字,它们和训练集MNIST里的图片非常接近。这算是一个不错的生成结果。

如果神经网络的拟合能力较弱,生成结果就会差很多。下图是我训练一个简单的ResNet后得到的采样结果:

可以看出,每幅图片都很乱,基本对应不上一个数字。这就是一个较差的训练结果。

如果网络再差一点,可能会生成纯黑或者纯白的图片。这是因为网络的预测结果不准,在反向过程中,图像的均值不断偏移,偏移到远大于1或者远小于-1的值了。

总结一下,在复现DDPM时,最主要是要学习DDPM论文的两个算法,即训练算法和采样算法。两个算法很简单,可以轻松地把它们翻译成代码。而为了成功完成复现,还需要花一点心思在编写U-Net上,尤其是注意处理时间戳的部分。

前段时间我写了一篇VQVAE的解读,现在再补充一篇VQVAE的PyTorch实现教程。在这个项目中,我们会实现VQVAE论文,在MNIST和CelebAHQ两个数据集上完成图像生成。具体来说,我们会先实现并训练一个图像压缩网络VQVAE,它能把真实图像编码成压缩图像,或者把压缩图像解码回真实图像。之后,我们会训练一个生成压缩图像的生成网络PixelCNN。

代码仓库:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/VQVAE

项目运行示例

如果你只是想快速地把项目运行起来,可以只阅读本节。

在本地安装好项目后,运行python dldemos/VQVAE/dataset.py来下载MNIST数据集。之后运行python dldemos/VQVAE/main.py,这个脚本会完成以下四个任务:

  1. 训练VQVAE
  2. 用VQVAE重建数据集里的随机数据
  3. 训练PixelCNN
  4. 用PixelCNN+VQVAE随机生成图片

第二步得到的重建结果大致如下(每对图片中左图是原图,右图是重建结果):

第四步得到的随机生成结果大致如下:

如果你要使用CelebAHQ数据集,请照着下一节的指示把CelebAHQ下载到指定目录,再执行python dldemos/VQVAE/main.py -c 4

数据集准备

MNIST数据集可以用PyTorch的API自动下载。我们可以用下面的代码下载MNIST数据集并查看数据的格式。从输出中可知,MNIST的图片形状为[1, 28, 28],颜色取值范围为[0, 1]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def download_mnist():
mnist = torchvision.datasets.MNIST(root='data/mnist', download=True)
print('length of MNIST', len(mnist))
id = 4
img, label = mnist[id]
print(img)
print(label)

# On computer with monitor
# img.show()

img.save('work_dirs/tmp_mnist.jpg')
tensor = transforms.ToTensor()(img)
print(tensor.shape)
print(tensor.max())
print(tensor.min())

我们可以用下面的代码把它封成简单的Dataset

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class MNISTImageDataset(Dataset):

def __init__(self, img_shape=(28, 28)):
super().__init__()
self.img_shape = img_shape
self.mnist = torchvision.datasets.MNIST(root='data/mnist')

def __len__(self):
return len(self.mnist)

def __getitem__(self, index: int):
img = self.mnist[index][0]
pipeline = transforms.Compose(
[transforms.Resize(self.img_shape),
transforms.ToTensor()])
return pipeline(img)

接下来准备CelebAHQ。CelebAHQ数据集原本的图像大小是1024x1024,但我们这个项目用不到这么大的图片。我在kaggle上找到了一个256x256的CelebAHQ (https://www.kaggle.com/datasets/badasstechie/celebahq-resized-256x256),所有文件加起来只有300MB左右,很适合我们项目。请在该页面下载压缩包,并把压缩包解压到项目的`data/celebA/celeba_hq_256`目录下。

下载完数据后,我们可以写一个简单的从目录中读取图片的Dataset类。和MNIST的预处理流程不同,我这里给CelebAHQ的图片加了一个中心裁剪的操作,一来可以让人脸占比更大,便于模型学习,二来可以让该类兼容CelebA数据集(CelebA数据集的图片不是正方形,需要裁剪)。这个操作是可选的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class CelebADataset(Dataset):

def __init__(self, root, img_shape=(64, 64)):
super().__init__()
self.root = root
self.img_shape = img_shape
self.filenames = sorted(os.listdir(root))

def __len__(self) -> int:
return len(self.filenames)

def __getitem__(self, index: int):
path = os.path.join(self.root, self.filenames[index])
img = Image.open(path)
pipeline = transforms.Compose([
transforms.CenterCrop(168),
transforms.Resize(self.img_shape),
transforms.ToTensor()
])
return pipeline(img)

有了数据集类后,我们可以用它们生成Dataloader

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
CELEBA_DIR = 'data/celebA/img_align_celeba'
CELEBA_HQ_DIR = 'data/celebA/celeba_hq_256'
def get_dataloader(type,
batch_size,
img_shape=None,
dist_train=False,
num_workers=4,
**kwargs):
if type == 'CelebA':
if img_shape is not None:
kwargs['img_shape'] = img_shape
dataset = CelebADataset(CELEBA_DIR, **kwargs)
elif type == 'CelebAHQ':
if img_shape is not None:
kwargs['img_shape'] = img_shape
dataset = CelebADataset(CELEBA_HQ_DIR, **kwargs)
elif type == 'MNIST':
if img_shape is not None:
dataset = MNISTImageDataset(img_shape)
else:
dataset = MNISTImageDataset()
if dist_train:
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers)
return dataloader, sampler
else:
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers)
return dataloader

我们可以利用Dataloader来查看CelebAHQ数据集的内容及数据格式。

1
2
3
4
5
6
7
8
9
10
11
12
13
if os.path.exists(CELEBA_HQ_DIR):
dataloader = get_dataloader('CelebAHQ', 16)
img = next(iter(dataloader))
print(img.shape)
N = img.shape[0]
img = einops.rearrange(img,
'(n1 n2) c h w -> c (n1 h) (n2 w)',
n1=int(N**0.5))
print(img.shape)
print(img.max())
print(img.min())
img = transforms.ToPILImage()(img)
img.save('work_dirs/tmp_celebahq.jpg')

从输出中可知,CelebAHQ的颜色取值范围同样是[0, 1]。经我们的预处理流水线得到的图片如下。

实现并训练 VQVAE

要用VQVAE做图像生成,其实要训练两个模型:一个是用于压缩图像的VQVAE,另一个是生成压缩图像的PixelCNN。这两个模型是可以分开训练的。我们先来实现并训练VQVAE。

VQVAE的架构非常简单:一个编码器,一个解码器,外加中间一个嵌入层。损失函数为图像的重建误差与编码器输出与其对应嵌入之间的误差。

VQVAE的编码器和解码器的结构也很简单,仅由普通的上/下采样层和残差块组成。具体来说,编码器先是有两个3x3卷积+2倍下采样卷积的模块,再有两个残差块(ReLU, 3x3卷积, ReLU, 1x1卷积);解码器则反过来,先有两个残差块,再有两个3x3卷积+2倍上采样反卷积的模块。为了让代码看起来更清楚一点,我们不用过度封装,仅实现一个残差块模块,再用残差块和PyTorch自带模块拼成VQVAE。

先实现残差块。注意,由于模型比较简单,残差块内部和VQVAE其他地方都可以不使用BatchNorm。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class ResidualBlock(nn.Module):

def __init__(self, dim):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)
self.conv2 = nn.Conv2d(dim, dim, 1)

def forward(self, x):
tmp = self.relu(x)
tmp = self.conv1(tmp)
tmp = self.relu(tmp)
tmp = self.conv2(tmp)
return x + tmp

有了残差块类后,我们可以直接实现VQVAE类。我们先在初始化函数里把模块按顺序搭好。编码器和解码器的结构按前文的描述搭起来即可。嵌入空间(codebook)其实就是个普通的嵌入层。此处我仿照他人代码给嵌入层显式初始化参数,但实测下来和默认的初始化参数方式差别不大。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class VQVAE(nn.Module):

def __init__(self, input_dim, dim, n_embedding):
super().__init__()
self.encoder = nn.Sequential(nn.Conv2d(input_dim, dim, 4, 2, 1),
nn.ReLU(), nn.Conv2d(dim, dim, 4, 2, 1),
nn.ReLU(), nn.Conv2d(dim, dim, 3, 1, 1),
ResidualBlock(dim), ResidualBlock(dim))
self.vq_embedding = nn.Embedding(n_embedding, dim)
self.vq_embedding.weight.data.uniform_(-1.0 / n_embedding,
1.0 / n_embedding)
self.decoder = nn.Sequential(
nn.Conv2d(dim, dim, 3, 1, 1),
ResidualBlock(dim), ResidualBlock(dim),
nn.ConvTranspose2d(dim, dim, 4, 2, 1), nn.ReLU(),
nn.ConvTranspose2d(dim, input_dim, 4, 2, 1))
self.n_downsample = 2

之后,我们来实现模型的前向传播。这里的逻辑就略显复杂了。整体来看,这个函数完成了编码、取最近邻、解码这三步。其中,取最近邻的部分最为复杂。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def forward(self, x):
# encode
ze = self.encoder(x)

# ze: [N, C, H, W]
# embedding [K, C]
embedding = self.vq_embedding.weight.data
N, C, H, W = ze.shape
K, _ = embedding.shape
embedding_broadcast = embedding.reshape(1, K, C, 1, 1)
ze_broadcast = ze.reshape(N, 1, C, H, W)
distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)
nearest_neighbor = torch.argmin(distance, 1)
# make C to the second dim
zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)
# stop gradient
decoder_input = ze + (zq - ze).detach()

# decode
x_hat = self.decoder(decoder_input)
return x_hat, ze, zq

我们来详细看一看取最近邻的实现。取最近邻时,我们要用到两块数据:编码器输出ze与嵌入矩阵embeddingze可以看成一个形状为[N, H, W]的数组,数组存储了长度为C的向量。而嵌入矩阵里有K个长度为C的向量。
1
2
3
4
5
# ze: [N, C, H, W]
# embedding [K, C]
embedding = self.vq_embedding.weight.data
N, C, H, W = ze.shape
K, _ = embedding.shape

为了求N*H*W个向量在嵌入矩阵里的最近邻,我们要先算这每个向量与嵌入矩阵里K个向量的距离。在算距离前,我们要把embeddingze的形状变换一下,保证(embedding_broadcast - ze_broadcast)**2的形状为[N, K, C, H, W]。我们对这个临时结果的第2号维度(C所在维度)求和,得到形状为[N, K, H, W]distance。它的含义是,对于N*H*W个向量,每个向量到嵌入空间里K个向量的距离分别是多少。
1
2
3
embedding_broadcast = embedding.reshape(1, K, C, 1, 1)
ze_broadcast = ze.reshape(N, 1, C, H, W)
distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)

有了距离张量后,我们再对其1号维度(K所在维度)求最近邻所在下标。

1
nearest_neighbor = torch.argmin(distance, 1)

有了下标后,我们可以用self.vq_embedding(nearest_neighbor)从嵌入空间取出最近邻了。别忘了,nearest_neighbor的形状是[N, H, W]self.vq_embedding(nearest_neighbor)的形状会是[N, H, W, C]。我们还要把C维度转置一下。

1
2
# make C to the second dim
zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)

最后,我们用论文里提到的停止梯度算子,把zq变形一下。这样,算误差的时候用的是zq,算梯度时ze会接收解码器传来的梯度。

1
2
# stop gradient
decoder_input = ze + (zq - ze).detach()

求最近邻的部分就到此结束了。最后再补充一句,前向传播函数不仅返回了重建结果x_hat,还返回了ze, zq。这是因为我们待会要在训练时根据ze, zq求损失函数。

准备好了模型类后,假设我们已经用某些超参数初始化好了模型model,我们可以用下面的代码训练VQVAE。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def train_vqvae(model: VQVAE,
img_shape=None,
device='cuda',
ckpt_path='dldemos/VQVAE/model.pth',
batch_size=64,
dataset_type='MNIST',
lr=1e-3,
n_epochs=100,
l_w_embedding=1,
l_w_commitment=0.25):
print('batch size:', batch_size)
dataloader = get_dataloader(dataset_type,
batch_size,
img_shape=img_shape,
use_lmdb=USE_LMDB)
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr)
mse_loss = nn.MSELoss()
tic = time.time()
for e in range(n_epochs):
total_loss = 0

for x in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)

x_hat, ze, zq = model(x)
l_reconstruct = mse_loss(x, x_hat)
l_embedding = mse_loss(ze.detach(), zq)
l_commitment = mse_loss(ze, zq.detach())
loss = l_reconstruct + \
l_w_embedding * l_embedding + l_w_commitment * l_commitment
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * current_batch_size
total_loss /= len(dataloader.dataset)
toc = time.time()
torch.save(model.state_dict(), ckpt_path)
print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
print('Done')

先看一下训练函数的参数。其他参数都没什么特别的,只有误差权重l_w_embedding=1,l_w_commitment=0.25需要讨论一下。误差函数有三项,但论文只给了第三项的权重(0.25),默认第二项的权重为1。我在实现时把第二项的权重l_w_embedding也加上了。

1
2
3
4
5
6
7
8
9
10
def train_vqvae(model: VQVAE,
img_shape=None,
device='cuda',
ckpt_path='dldemos/VQVAE/model.pth',
batch_size=64,
dataset_type='MNIST',
lr=1e-3,
n_epochs=100,
l_w_embedding=1,
l_w_commitment=0.25):

再来把函数体过一遍。一开始,我们可以用传来的参数把dataloader初始化一下。

1
2
3
4
5
print('batch size:', batch_size)
dataloader = get_dataloader(dataset_type,
batch_size,
img_shape=img_shape,
use_lmdb=USE_LMDB)

再把模型的状态调好,并准备好优化器和算均方误差的函数。

1
2
3
4
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr)
mse_loss = nn.MSELoss()

准备好变量后,进入训练循环。训练的过程比较常规,唯一要注意的就是误差计算部分。由于我们把复杂的逻辑都放在了模型类中,这里我们可以直接先用model(x)得到重建图像x_hat和算误差的ze, zq,再根据论文里的公式算3个均方误差,最后求一个加权和,代码比较简明。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
for e in range(n_epochs):
for x in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)

x_hat, ze, zq = model(x)
l_reconstruct = mse_loss(x, x_hat)
l_embedding = mse_loss(ze.detach(), zq)
l_commitment = mse_loss(ze, zq.detach())
loss = l_reconstruct + \
l_w_embedding * l_embedding + l_w_commitment * l_commitment
optimizer.zero_grad()
loss.backward()
optimizer.step()

训练完毕后,我们可以用下面的代码来测试VQVAE的重建效果。所谓重建,就是模拟训练的过程,随机取一些图片,先编码后解码,看解码出来的图片和原图片是否一致。为了获取重建后的图片,我们只需要直接执行前向传播函数model(x)即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def reconstruct(model, x, device, dataset_type='MNIST'):
model.to(device)
model.eval()
with torch.no_grad():
x_hat, _, _ = model(x)
n = x.shape[0]
n1 = int(n**0.5)
x_cat = torch.concat((x, x_hat), 3)
x_cat = einops.rearrange(x_cat, '(n1 n2) c h w -> (n1 h) (n2 w) c', n1=n1)
x_cat = (x_cat.clip(0, 1) * 255).cpu().numpy().astype(np.uint8)
if dataset_type == 'CelebA' or dataset_type == 'CelebAHQ':
x_cat = cv2.cvtColor(x_cat, cv2.COLOR_RGB2BGR)
cv2.imwrite(f'work_dirs/vqvae_reconstruct_{dataset_type}.jpg', x_cat)

vqvae = ...
dataloader = get_dataloader(...)
img = next(iter(dataloader)).to(device)
reconstruct(vqvae, img, device, cfg['dataset_type'])

训练压缩图像生成模型 PixelCNN

有了一个VQVAE后,我们要用另一个模型对VQVAE的离散空间采样,也就是训练一个能生成压缩图片的模型。我们可以按照VQVAE论文的方法,使用PixelCNN来生成压缩图片。

PixelCNN 的原理及实现方法就不在这里过多介绍了。详情可以参见我之前的PixelCNN解读文章。简单来说,PixelCNN给每个像素从左到右,从上到下地编了一个序号,让每个像素仅由之前所有像素决定。采样时,PixelCNN按序号从左上到右下逐个生成图像的每一个像素;训练时,PixelCNN使用了某种掩码机制,使得每个像素只能看到编号更小的像素,并行地输出每一个像素的生成结果。

PixelCNN具体的训练示意图如下。模型的输入是一幅图片,每个像素的取值是0~255;模型给图片的每个像素输出了一个概率分布,即表示此处颜色取0,取1,……,取255的概率。由于神经网络假设数据的输入符合标准正态分布,我们要在数据输入前把整型的颜色转换成0~1之间的浮点数。最简单的转换方法是除以255。

以上是训练PixelCNN生成普通图片的过程。而在训练PixelCNN生成压缩图片时,上述过程需要修改。压缩图片的取值是离散编码。离散编码和颜色值不同,它不是连续的。你可以说颜色1和颜色0、2相近,但不能说离散编码1和离散编码0、2相近。因此,为了让PixelCNN建模离散编码,需要把原来的除以255操作换成一个嵌入层,使得网络能够读取离散编码。

反映在代码中,假设我们已经有了一个普通的PixelCNN模型GatedPixelCNN,我们需要在整个模型的最前面套一个嵌入层,嵌入层的嵌入个数等于离散编码的个数(color_level),嵌入长度等于模型的特征长度(p)。由于嵌入层会直接输出一个长度为p的向量,我们还需要把第一个模块的输入通道数改成p

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from dldemos.pixelcnn.model import GatedPixelCNN, GatedBlock

import torch.nn as nn


class PixelCNNWithEmbedding(GatedPixelCNN):

def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
super().__init__(n_blocks, p, linear_dim, bn, color_level)
self.embedding = nn.Embedding(color_level, p)
self.block1 = GatedBlock('A', p, p, bn)

def forward(self, x):
x = self.embedding(x)
x = x.permute(0, 3, 1, 2).contiguous()
return super().forward(x)

有了一个能处理离散编码的PixelCNN后,我们可以用下面的代码来训练PixelCNN。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def train_generative_model(vqvae: VQVAE,
model,
img_shape=None,
device='cuda',
ckpt_path='dldemos/VQVAE/gen_model.pth',
dataset_type='MNIST',
batch_size=64,
n_epochs=50):
print('batch size:', batch_size)
dataloader = get_dataloader(dataset_type,
batch_size,
img_shape=img_shape,
use_lmdb=USE_LMDB)
vqvae.to(device)
vqvae.eval()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
loss_fn = nn.CrossEntropyLoss()
tic = time.time()
for e in range(n_epochs):
total_loss = 0
for x in dataloader:
current_batch_size = x.shape[0]
with torch.no_grad():
x = x.to(device)
x = vqvae.encode(x)

predict_x = model(x)
loss = loss_fn(predict_x, x)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * current_batch_size
total_loss /= len(dataloader.dataset)
toc = time.time()
torch.save(model.state_dict(), ckpt_path)
print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
print('Done')
gen_model = PixelCNNWithEmbedding(cfg['pixelcnn_n_blocks'],
cfg['pixelcnn_dim'],
cfg['pixelcnn_linear_dim'], True,
cfg['n_embedding'])
vqvae.load_state_dict(torch.load(cfg['vqvae_path']))
train_generative_model(vqvae,
gen_model,
img_shape=(img_shape[1], img_shape[2]),
device=device,
ckpt_path=cfg['gen_model_path'],
dataset_type=cfg['dataset_type'],
batch_size=cfg['batch_size_2'],
n_epochs=cfg['n_epochs_2'])

训练部分的核心代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
loss_fn = nn.CrossEntropyLoss()
for x in dataloader:
current_batch_size = x.shape[0]
with torch.no_grad():
x = x.to(device)
x = vqvae.encode(x)

predict_x = model(x)
loss = loss_fn(predict_x, x)
optimizer.zero_grad()
loss.backward()
optimizer.step()

这段代码的意思是说,从训练集里随机取图片x,再将图片压缩成离散编码x = vqvae.encode(x)。这时,x既是PixelCNN的输入,也是PixelCNN的拟合目标。把它输入进PixelCNN,PixelCNN会输出每个像素的概率分布。用交叉熵损失函数约束输出结果即可。

训练完毕后,我们可以用下面的函数来完成整套图像生成流水线。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def sample_imgs(vqvae: VQVAE,
gen_model,
img_shape,
n_sample=81,
device='cuda',
dataset_type='MNIST'):
vqvae = vqvae.to(device)
vqvae.eval()
gen_model = gen_model.to(device)
gen_model.eval()

C, H, W = img_shape
H, W = vqvae.get_latent_HW((C, H, W))
input_shape = (n_sample, H, W)
x = torch.zeros(input_shape).to(device).to(torch.long)
with torch.no_grad():
for i in range(H):
for j in range(W):
output = gen_model(x)
prob_dist = F.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist, 1)
x[:, i, j] = pixel[:, 0]

imgs = vqvae.decode(x)

imgs = imgs * 255
imgs = imgs.clip(0, 255)
imgs = einops.rearrange(imgs,
'(n1 n2) c h w -> (n1 h) (n2 w) c',
n1=int(n_sample**0.5))

imgs = imgs.detach().cpu().numpy().astype(np.uint8)
if dataset_type == 'CelebA' or dataset_type == 'CelebAHQ':
imgs = cv2.cvtColor(imgs, cv2.COLOR_RGB2BGR)

cv2.imwrite(f'work_dirs/vqvae_sample_{dataset_type}.jpg', imgs)

抛掉前后处理,和图像生成有关的代码如下。一开始,我们要随便创建一个空图片x,用于储存PixelCNN生成的压缩图片。之后,我们按顺序遍历每个像素,把当前图片输入进PixelCNN,让PixelCNN预测下一个像素的概率分布prob_dist。我们再用torch.multinomial从概率分布中采样,把采样的结果填回图片。遍历结束后,我们用VQVAE的解码器把压缩图片变成真实图片。

1
2
3
4
5
6
7
8
9
10
11
12
13
C, H, W = img_shape
H, W = vqvae.get_latent_HW((C, H, W))
input_shape = (n_sample, H, W)
x = torch.zeros(input_shape).to(device).to(torch.long)
with torch.no_grad():
for i in range(H):
for j in range(W):
output = gen_model(x)
prob_dist = F.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist, 1)
x[:, i, j] = pixel[:, 0]

imgs = vqvae.decode(x)

至此,我们已经实现了用VQVAE做图像生成的四个任务:训练VQVAE、重建图像、训练PixelCNN、随机生成图像。完整的main函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
if __name__ == '__main__':
os.makedirs('work_dirs', exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument('-c', type=int, default=0)
parser.add_argument('-d', type=int, default=0)
args = parser.parse_args()
cfg = get_cfg(args.c)

device = f'cuda:{args.d}'

img_shape = cfg['img_shape']

vqvae = VQVAE(img_shape[0], cfg['dim'], cfg['n_embedding'])
gen_model = PixelCNNWithEmbedding(cfg['pixelcnn_n_blocks'],
cfg['pixelcnn_dim'],
cfg['pixelcnn_linear_dim'], True,
cfg['n_embedding'])
# 1. Train VQVAE
train_vqvae(vqvae,
img_shape=(img_shape[1], img_shape[2]),
device=device,
ckpt_path=cfg['vqvae_path'],
batch_size=cfg['batch_size'],
dataset_type=cfg['dataset_type'],
lr=cfg['lr'],
n_epochs=cfg['n_epochs'],
l_w_embedding=cfg['l_w_embedding'],
l_w_commitment=cfg['l_w_commitment'])

# 2. Test VQVAE by visualizaing reconstruction result
vqvae.load_state_dict(torch.load(cfg['vqvae_path']))
dataloader = get_dataloader(cfg['dataset_type'],
16,
img_shape=(img_shape[1], img_shape[2]))
img = next(iter(dataloader)).to(device)
reconstruct(vqvae, img, device, cfg['dataset_type'])

# 3. Train Generative model (Gated PixelCNN in our project)
vqvae.load_state_dict(torch.load(cfg['vqvae_path']))

train_generative_model(vqvae,
gen_model,
img_shape=(img_shape[1], img_shape[2]),
device=device,
ckpt_path=cfg['gen_model_path'],
dataset_type=cfg['dataset_type'],
batch_size=cfg['batch_size_2'],
n_epochs=cfg['n_epochs_2'])

# 4. Sample VQVAE
vqvae.load_state_dict(torch.load(cfg['vqvae_path']))
gen_model.load_state_dict(torch.load(cfg['gen_model_path']))
sample_imgs(vqvae,
gen_model,
cfg['img_shape'],
device=device,
dataset_type=cfg['dataset_type'])

实验

VQVAE有两个超参数:嵌入个数n_embedding、特征向量长度dim。论文中n_embedding=512dim=256。而经我实现发现,用更小的参数量也能达到不错的效果。

所有实验的配置文件我都放在了该项目目录下config.py文件中。对于MNIST数据集,我使用的模型超参数为:dim=32, n_embedding=32。VQVAE重建结果如下所示。可以说重建得几乎完美(每对图片左图为原图,右图为重建结果)。

而对于CelebAHQ数据集,我测试了不同输入尺寸下的不同VQVAE,共有4组配置。

  1. shape=(3, 128, 128) dim=128 n_embedding=64
  2. shape=(3, 128, 128) dim=128 n_embedding=128
  3. shape=(3, 64, 64) dim=128 n_embedding=64
  4. shape=(3, 64, 64) dim=128 n_embedding=32

实验的结果很好预测。对于同尺寸的图片,嵌入数越多重建效果越好。这里我只展示下第一组和第二组的重建结果。

可以看出,VQVAE的重建效果还不错。但由于只使用了均方误差,重建图片在细节上还是比较模糊。重建效果还是很重要的,它决定了该方法做图像生成的质量上限。后续有很多工作都试图提升VQVAE的重建效果。

接下来来看一下随机图像生成的实验。PixelCNN主要有模块数n_blocks、特征长度dim,输出线性层特征长度linear_dim这三个超参数。其中模块数一般是固定的,而输出线性层就被用了一次,其特征长度的影响不大。最需要调节的是特征长度dim。对于MNIST,我的超参数设置为

  • n_blocks=15 dim=128 linear_dim=32.

对于CelebAHQ,我的超参数设置为

  • n_blocks=15 dim=384 linear_dim=256.

PixelCNN的训练时间主要由输入图片尺寸和dim决定,训练难度主要由VQVAE的嵌入个数(即多分类的类别数)决定。PixelCNN训起来很花时间。如果时间有限,在CelebAHQ上建议只训练最小最简单的第4组配置。我在项目中提供了PixelCNN的并行训练脚本,比如用下面的命令可以用4张卡在1号配置下并行训练。

1
torchrun --nproc_per_node=4 dldemos/VQVAE/dist_train_pixelcnn.py -c 1

来看一下实验结果。MNIST上的采样结果还是非常不错的。

CelebAHQ上的结果会差一点。以下是第4组配置(图像边长64,嵌入数32)的采样结果。大部分图片都还行,起码看得出是一张人脸。但64x64的图片本来就分辨率不高,加上VQVAE解码的损耗,放大来看人脸还是比较模糊的。

第1组配置(图像边长128,嵌入数64)的PixelCNN实在训练得太慢了,我只训了一个半成品模型。由于部分生成结果比较吓人,我只挑了几个还能看得过去的生成结果。可以看出,如果把模型训完的话,边长128的模型肯定比边长64的模型效果更好。

参考资料

网上几乎找不到在CelebAHQ上训练的VQVAE PyTorch项目。我在实现这份代码时,参考了以下项目:

实验经历分享

别看VQVAE的代码不难,我做这些实验时还是经历了不少波折的。

一开始,我花一天就把代码写完了,并完成了MNIST上的实验。我觉得在MNIST上做实验的难度太低,不过瘾,就准备把数据集换成CelebA再做一些实验。结果这一做就是两个星期。

换成CelebA后,我碰到的第一个问题是VQVAE训练速度太慢。我尝试减半模型参数,训练时间却减小得不明显。我大致猜出是数据读取占用了大量时间,用性能分析工具一查,果然如此。原来我在DataLoader中一直只用了一个线程,加上num_workers=4就好了。我还把数据集打包成LMDB格式进一步加快数据读取速度。

之后,我又发现VQVAE在CelebA上的重建效果很差。我尝试增加模型参数,没起作用。我又怀疑是64x64的图片质量太低,模型学不到东西,就尝试把输入尺寸改成128x128,并把数据集从CelebA换成CelebAHQ,重建效果依然不行。我调了很多参数,发现了一些奇怪的现象:在嵌入层前使用和不使用BatchNorm对结果的影响很大,且显式初始化嵌入层会让模型的误差一直居高不下。我实在是找不到问题,就拿代码对着别人的PyTorch实现一行一行比较过去。总算,我发现我在使用嵌入层时是用vq_embedding.weight.data[x](因为前面已经获取了这个矩阵,这样写比较自然),别人是用vq_embedding(x)。我的写法会把嵌入层排除在梯度计算外,嵌入层根本得不到优化。我说怎么换了一个嵌入层的初始化方法模型就根本训不动了。改完bug之后,只训了5个epoch,新模型的误差比原来训练数小时的模型要低了。新模型的重建效果非常好。

总算,任务完成了一半,现在只剩PixelCNN要训练了。我先尝试训练输入为128x128,嵌入数64的模型,采样结果很差。为了加快实验速度,我把输入尺寸减小到64x64,再次训练,采样结果还是不行。根据我之前的经验,PixelCNN的训练难度主要取决于类别数。于是,我把嵌入的数量从64改成了32,并大幅增加PixelCNN的参数量,再次训练。过了很久,训练误差终于降到0.08左右。我一测,这次的采样结果还不错。

这样看来,之前的采样效果不好,是输入128x128,嵌入数64的实验太难了。我毕竟只是想做一个demo,在一个小型实验上成功就行了,没必要花时间去做更耗时的实验。按理说,我应该就此收手。但是,我就是咽不下这一口气,就是想在128x128的实验上成功。我再次加大了PixelCNN的参数量,用128x128的配置,大火慢炖,训练了一天一夜。第二天一早起来,我看到这回的误差也降到了0.08。上次的实验误差降到这个程度时实验已经成功了。我迫不及待地去测试采样效果,却发现采样效果还是稀烂。没办法,我选择投降,开始写这篇文章,准备收工。

写到PixelCNN介绍的那一章节时,我正准备讲解代码。看到PixelCNN训练之前预处理除以color_level那一行时,我楞了一下:这行代码是用来做什么的来着?这段代码全是从PixelCNN项目里复制过来的。当时是做普通图片的图像生成,所以要对输入颜色做一个预处理,把整数颜色变成0~1之间的浮点数。但现在是在生成压缩图片,不能这样处理啊!我恍然大悟,知道是在处理离散输入时做错了。应该多加一个嵌入层,把离散值转换成向量。由于VQVAE的重点不在生成模型上,原论文根本没有强调PixelCNN在离散编码上的实现细节。网上几乎所有文章也都没谈这一点。因此,我在实现PixelCNN时,直接不假思索地把原来的代码搬了过来,根本没想过这种地方会出现bug。

把这处bug改完后,我再次开启训练。这下所有模型的采样结果都正常了。误差降到0.5左右就已经有不错的采样结果了,原来我之前把误差降到0.08完全是无用功。太气人了。

这次的实验让我学到了很多东西。首先是PyTorch编程上的一些注意事项:

  • 调用embedding.weight.data[x]是传不了梯度的。
  • 如果读数据时有费时的处理操作(读写硬盘、解码),要在Dataloader里设置num_workers

另外,在测试一个模型是否实现成功时有一个重要的准则:

  • 不要仅在简单的数据集(如MNIST)上测试。测试成功可能只是暴力拟合的结果。只有在一个难度较大的数据集上测试成功才能说模型没有问题。

在观察模型是否训成功时,还需要注意:

  • 训练误差降低不代表模型更优。训练误差的评价方法和模型实际使用方法可能完全不同。不能像我这样偷懒不加测试指标。

除了学到的东西外,我还有一些感想。在别人的项目的基础上修改、照着他人代码复现、完全自己动手从零开始写,对于深度学习项目来说,这三种实现方式的难度是依次递增的。改别人的项目,你可能去配置文件里改一两个数字就行了。而照着他人代码复现,最起码你能把代码改成和他人的代码一模一样,然后再去比较哪一块错了。自己动手写,则是有bug都找不到可以参考的地方了。说深度学习的算法难以调试,难就难在这里。效果不好,你很难说清是训练代码错了、超参数没设置好、训练流程错了,或是测试代码错了。可以出错的地方太多了,通常的代码调试手段难以用在深度学习项目上。

对于想要在深度学习上有所建树的初学者,我建议一定要从零动手复现项目。很多工程经验是难以总结的,只有踩了一遍坑才能知道。除了凭借经验外,还可以掌握一些特定的工程方法来减少bug的出现。比如运行训练之前先拿性能工具分析一遍,看看代码是否有误,是否可以提速;又比如可以训练几步后看所有可学习参数是否被正确修改。

2022年中旬,以扩散模型为核心的图像生成模型将AI绘画带入了大众的视野。实际上,在更早的一年之前,就有了一个能根据文字生成高清图片的模型——VQGAN。VQGAN不仅本身具有强大的图像生成能力,更是传承了前作VQVAE把图像压缩成离散编码的思想,推广了「先压缩,再生成」的两阶段图像生成思路,启发了无数后续工作。

VQGAN生成出的高清图片

在这篇文章中,我将对VQGAN的论文和源码中的关键部分做出解读,提炼出VQGAN中的关键知识点。由于VQGAN的核心思想和VQVAE如出一辙,我不会过多地介绍VQGAN的核心思想,强烈建议读者先去学懂VQVAE,再来看VQGAN。

VQGAN 核心思想

VQGAN的论文名为Taming Transformers for High-Resolution Image Synthesis,直译过来是「驯服Transformer模型以实现高清图像合成」。可以看出,该方法是在用Transformer生成图像。可是,为什么这个模型叫做VQGAN,是一个GAN呢?这是因为,VQGAN使用了两阶段的图像生成方法:

  • 训练时,先训练一个图像压缩模型(包括编码器和解码器两个子模型),再训练一个生成压缩图像的模型。
  • 生成时,先用第二个模型生成出一个压缩图像,再用第一个模型复原成真实图像。

其中,第一个图像压缩模型叫做VQGAN,第二个压缩图像生成模型是一个基于Transformer的模型。

为什么会有这种乍看起来非常麻烦的图像生成方法呢?要理解VQGAN的这种设计动机,有两条路线可以走。两条路线看待问题的角度不同,但实际上是在讲同一件事。

第一条路线是从Transformer入手。Transformer已经在文本生成领域大展身手。同时,Transformer也在视觉任务中开始崭露头角。相比擅长捕捉局部特征的CNN,Transformer的优势在于它能更好地融合图像的全局信息。可是,Transformer的自注意力操作开销太大,只能生成一些分辨率较低的图像。因此,作者认为,可以综合CNN和Transformer的优势,先用基于CNN的VQGAN把图像压缩成一个尺寸更小、信息更丰富的小图像,再用Transformer来生成小图像。

第二条路线是从VQVAE入手。VQVAE是VQGAN的前作,它有着和VQGAN一模一样两阶段图像生成方法。不同的是,VQVAE没有使用GAN结构,且其配套的压缩图像生成模型是基于CNN的。为提升VQVAE的生成效果,作者提出了两项改进策略:1) 图像压缩模型VQVAE仅使用了均方误差,压缩图像的复原结果较为模糊,可以把图像压缩模型换成GAN;2) 在生成压缩图片这个任务上,基于CNN的图像生成模型比不过Transformer,可以用Transformer代替原来的CNN。

第一条思路是作者在论文的引言中描述的,听起来比较高大上;而第二条思路是读者读过文章后能够自然总结出来的,相对来说比较清晰易懂。如果你已经理解了VQVAE,你能通过第二条思路瞬间弄懂VQGAN的原理。说难听点,VQGAN就是一个改进版的VQVAE。然而,VQGAN的改进非常有效,且使用了若干技巧来实现带约束(比如根据文字描述)的高清图像生成,有非常多地方值得学习。

在下文中,我将先补充VQVAE的背景以方便讨论,再介绍VQGAN论文的四大知识点:VQGAN的设计细节、生成压缩图像的Transformer的设计细节、带约束图像生成的实现方法、高清图像生成的实现方法。

VQVAE 背景知识补充

VQVAE的学习目标是用一个编码器把图像压缩成离散编码,再用一个解码器把图像尽可能地还原回原图像。

通俗来说,VQVAE就是把一幅真实图像压缩成一个小图像。这个小图像和真实图像有着一些相同的性质:小图像的取值和像素值(0-255的整数)一样,都是离散的;小图像依然是二维的,保留了某些空间信息。因此,VQVAE的示意图画成这样会更形象一些:

但小图像和真实图像有一个关键性的区别:与像素值不同,小图像的离散取值之间没有关联。真实图像的像素值其实是一个连续颜色的离散采样,相邻的颜色值也更加相似。比如颜色254和颜色253和颜色255比较相似。而小图像的取值之间是没有关联的,你不能说编码为1与编码为0和编码为2比较相似。由于神经网络不能很好地处理这种离散量,在实际实现中,编码并不是以整数表示的,而是以类似于NLP中的嵌入向量的形式表示的。VAE使用了嵌入空间(又称codebook)来完成整数序号到向量的转换。

为了让任意一个编码器输出向量都变成一个固定的嵌入向量,VQVAE采取了一种离散化策略:把每个输出向量$z_e(x)$替换成嵌入空间中最近的那个向量$z_q(x)$。$z_e(x)$的离散编码就是$z_q(x)$在嵌入空间的下标。这个过程和把254.9的输出颜色值离散化成255的整数颜色值的原理类似。

VQVAE的损失函数由两部分组成:重建误差和嵌入空间误差。

其中,重建误差就是输入和输出之间的均方误差。

嵌入空间误差为解码器输出向量$z_e(x)$和它在嵌入空间对应向量$z_q(x)$的均方误差。

作者在误差中还使用了一种「停止梯度」的技巧。这个技巧在VQGAN中被完全保留,此处就不过多介绍了。

图像压缩模型 VQGAN

回顾了VQVAE的背景知识后,我们来正式认识VQGAN的几个创新点。第一点,图像压缩模型VQVAE被改进成了VQGAN。

一般VAE重建出来出来的图像都会比较模糊。这是因为VAE只使用了均方误差,而均方误差只能保证像素值尽可能接近,却不能保证图像的感知效果更加接近。为此,作者把GAN的一些方法引入VQVAE,改造出了VQGAN。

具体来说,VQGAN有两项改进。第一,作者用感知误差(perceptual loss)代替原来的均方误差作为VQGAN的重建误差。第二,作者引入了GAN的对抗训练机制,加入了一个基于图块的判别器,把GAN误差加入了总误差。

计算感知误差的方法如下:把两幅图像分别输入VGG,取出中间某几层卷积层的特征,计算特征图像之间的均方误差。如果你之前没学过相关知识,请搜索”perceptual loss”。

基于图块的判别器,即判别器不为整幅图输出一个真或假的判断结果,而是把图像拆成若干图块,分别输出每个图块的判断结果,再对所有图块的判断结果取一个均值。这只是GAN的一种改进策略而已,没有对GAN本身做太大的改动。如果你之前没学过相关知识,请搜索”PatchGAN”。

这样,总的误差可以写成:

其中,$\lambda$是控制两种误差比例的权重。作者在论文中使用了一个公式来自适应地设置$\lambda$。和普通的GAN一样,VQGAN的编码器、解码器(即生成器)、codebook会最小化误差,判别器会最大化误差。

用VQGAN代替VQVAE后,重建图片中的模糊纹理清晰了很多。

有了一个保真度高的图像压缩模型,我们可以进入下一步,训练一个生成压缩图像的模型。

基于 Transformer 的压缩图像生成模型

如前所述,经VQGAN得到的压缩图像与真实图像有一个本质性的不同:真实图像的像素值具有连续性,相邻的颜色更加相似,而压缩图像的像素值则没有这种连续性。压缩图像的这一特性让寻找一个压缩图像生成模型变得异常困难。多数强大的真实图像生成模型(比如GAN)都是输出一个连续的浮点颜色值,再做一个浮点转整数的操作,得到最终的像素值。而对于压缩图像来说,这种输出连续颜色的模型都不适用了。因此,之前的VQVAE使用了一个能建模离散颜色的PixelCNN模型作为压缩图像生成模型。但PixelCNN的表现不够优秀。

恰好,功能强大的Transformer天生就支持建模离散的输出。在NLP中,每个单词都可以用一个离散的数字表示。Transformer会不断生成表示单词的数字,以达到生成句子的效果。

Transformer 随机生成句子的过程

为了让Transformer生成图像,我们可以把生成句子的一个个单词,变成生成压缩图像的一个个像素。但是,要让Transformer生成二维图像,还需要克服一个问题:在生成句子时,Transformer会先生成第一个单词,再根据第一个单词生成第二个单词,再根据第一、第二个单词生成第三个单词……。也就是说,Transformer每次会根据之前所有的单词来生成下一单词。而图像是二维数据,没有先后的概念,怎样让像素和文字一样有先后顺序呢?

VQGAN的作者使用了自回归图像生成模型的常用做法,给图像的每个像素从左到右,从上到下规定一个顺序。有了先后顺序后,图像就可以被视为一个一维句子,可以用Transfomer生成句子的方式来生成图像了。在第$i$步,Transformer会根据前$i - 1$个像素$s_{ < i}$生成第$i$个像素$s_i$,

带约束的图像生成

在生成新图像时,我们更希望模型能够根据我们的需求生成图像。比如,我们希望模型生成「一幅优美的风景画」,又或者希望模型在一幅草图的基础上作画。这些需求就是模型的约束。为了实现带约束的图像生成,一般的做法是先有一个无约束(输入是随机数)的图像生成模型,再在这个模型的基础上把一个表示约束的向量插入进图像生成的某一步。

把约束向量插入进模型的方法是需要设计的,插入约束向量的方法往往和模型架构有着密切关系。比如假设一个生成模型是U-Net架构,我们可以把约束向量和当前特征图拼接在一起,输入进U-Net的每一大层。

为了实现带约束的图像生成,VQGAN的作者再次借鉴了Transformer实现带约束文字生成的方法。许多自然语言处理任务都可以看成是带约束的文字生成。比如机器翻译,其实可以看成在给定一种语言的句子的前提下,让模型「随机」生成一个另一种语言的句子。比如要把「简要访问非洲」翻译成英语,我们可以对之前无约束文字生成的Transformer做一些修改。

也就是说,给定约束的句子$c$,在第$i$步,Transformer会根据前$i-1$个输出单词$s_{ < i}$以及$c$生成第$i$个单词$s_i$。表示约束的单词被添加到了所有输出之前,作为这次「随机生成」的额外输入。

上述方法并不是唯一的文字生成方法。这种文字生成方法被称为”decoder-only”。实际上,也有使用一个编码器来额外维护约束信息的文字生成方法。最早的Transformer就用到了带编码器的方法。

我们同样可以把这种思想搬到压缩图像生成里。比如对于MNIST数据集,我们希望模型只生成0~9这些数字中某一个数字的手写图像。也就是说,约束是类别信息,约束的取值是0~9。我们就可以把这个0~9的约束信息添加到Transformer的输入$s_{ < i}$之前,以实现由类别约束的图像生成。

但这种设计又会产生一个新的问题。假设约束条件不能简单地表示成整数,而是一些其他类型的数据,比如语义分割图像,那该怎么办呢?对于这种以图像形式表示的约束,作者的做法是,再训练另一个VQGAN,把约束图像压缩成另一套压缩图片。这一套压缩图片和生成图像的压缩图片有着不同的codebook,就像两种语言有着不同的单词一样。这样,约束图像也变成了一系列的整数,可以用之前的方法进行带约束图像生成了。

生成高清图像

由于Transformer注意力计算的开销很大,作者在所有配置中都只使用了$16 \times 16$的压缩图像,再增大压缩图像尺寸的话计算资源就不够了。而另一方面,每张图像在VQGAN中的压缩比例是有限的。如果图像压缩得过多,则VQGAN的重建质量就不够好了。因此,设边长压缩了$f$倍,则该方法一次能生成的图片的最大尺寸是$16f \times 16f$。在多项实验中,$f=16$的表现都较好。这样算下来,该方法一次只能生成$256 \times 256$的图片。这种尺寸的图片还称不上高清图片。

为了生成更大尺寸的图片,作者先训练好了一套能生成$256 \times 256$的图片的VQGAN+Transformer,再用了一种基于滑动窗口的采样机制来生成大图片。具体来说,作者把待生成图片划分成若干个$16\times16$像素的图块,每个图块对应压缩图像的一个像素。之后,在每一轮生成时,只有待生成图块周围的$16\times16$个图块($256\times256$个像素)会被输入进VQGAN和Transformer,由Transformer生成一个新的压缩图像像素,再把该压缩图像像素解码成图块。(在下面的示意图中,每个方块是一个图块,transformer的输入是$3\times3$个图块)

这个滑动窗口算法不是那么好理解,需要多想一下才能理解它的具体做法。在理解这个算法时,你可能会有这样的问题:上面的示意图中,待生成像素有的时候在最左边,有的时候在中间,有的时候在右边,每次约束它的像素都不一样。这么复杂的约束逻辑怎么编写?其实,Transformer自动保证了每个像素只会由之前的像素约束,而看不到后面的像素。因此,在实现时,只需要把待生成像素框起来,直接用Transformer预测待生成像素即可,不需要编写额外的约束逻辑。

如果你没有学过Transformer的话,理解这部分会有点困难。Transformer可以根据第1~k-1个像素并行地生成第2~k个像素,且保证生成每个像素时不会偷看到后面像素的信息。因此,假设我们要生成第i个像素,其实是预测了所有第2~k个像素的结果,再取出第i个结果,填回待生成图像。

由于论文篇幅有限,作者没有对滑动窗口机制做过多的介绍,也没有讲带约束的滑动窗口是怎么实现的。如果你在理解这一部分时碰到了问题,不用担心,这很正常。稍后我们会在代码阅读章节彻底理解滑动窗口的实现方法。我也是看了代码才看懂此处的做法。

作者在论文中解释了为什么用滑动窗口生成高清图像是合理的。作者先是讨论了两种情况,只要满足这两种情况中的任意一种,拿滑动窗口生成图像就是合理的。第一种情况是数据集的统计规律是几乎空间不变,也就是说训练集图片每$256\times256$个像素的统计规律是类似的。这和我们拿$3\times3$卷积卷图像是因为图像每$3\times3$个像素的统计规律类似的原理是一样的。第二种情况是有空间上的约束信息。比如之前提到的用语义分割图来指导图像生成。由于语义分割也是一张图片,它给每个待生成像素都提供了额外信息。这样,哪怕是用滑动窗口,在局部语义的指导下,模型也足以生成图像了。

若是两种情况都不满足呢?比如在对齐的人脸数据集上做无约束生成。在对齐的人脸数据集里,每张图片中人的五官所在的坐标是差不多的,图片的空间不变性不满足;做无约束生成,自然也没有额外的空间信息。在这种情况下,我们可以人为地添加一个坐标约束,即从左到右、从上到下地给每个像素标一个序号,把每个滑动窗口里的坐标序号做为约束。有了坐标约束后,就还原成了上面的第二种情况,每个像素有了额外的空间信息,基于滑动窗口的方法依然可行。

学完了论文的四大知识点,我们知道VQGAN是怎么根据约束生成高清图像的了。接下来,我们来看看论文的实验部分,看看作者是怎么证明方法的有效性的。

实验

在实验部分,作者先是分别验证了基于Transformer的压缩图像生成模型较以往模型的优越性(4.1节)、VQGAN较以往模型的优越性(4.4节末尾)、使用VQGAN做图像压缩的必要性及相关消融实验(4.3节),再把整个生成方法综合起来,在多项图像生成任务上与以往的图像生成模型做定量对比(4.4节),最后展示了该方法惊艳的带约束生成效果(4.2节)。

在论文4.1节中,作者验证了基于Transformer的压缩图像生成模型的有效性。之前,压缩图像都是使用能输出离散分布的PixelCNN系列模型来生成的。PixelCNN系列的最强模型是PixelSNAIL。为确保公平,作者对比了相同训练时间、相同训练步数下两个网络在不同训练集下的负对数似然(NLL)指标。结果表明,基于Transformer的模型确实训练得更快。

对于直接能建模离散分布的模型来说,NLL就是交叉熵损失函数。

在论文4.4节末尾,作者将VQGAN和之前的图像压缩模型对比,验证了引入感知误差和GAN结构的有效性。作者汇报了各模型重建图像集与原数据集(ImageNet的训练集和验证集)的FID(指标FID是越低越好)。同时,结果也说明,增大codebook的尺寸或者编码种类都能提升重建效果。

在论文4.3节中,作者验证了使用VQGAN的必要性。作者训了两个模型,一个直接让Transformer做真实图像生成,一个用VQGAN把图像边长压缩2倍,再用Transformer生成压缩图像。经比较,使用了VQGAN后,图像生成速度快了10多倍,且图像生成效果也有所提升。

另外,作者还做了有关图像边长压缩比例$f$的消融实验。作者固定让Transformer生成$16 \times 16$的压缩图片,即每次训练时用到的图像尺寸都是$16f \times 16f$。之后,作者训练训练了不同$f$下的模型,用各个模型来生成图片。结果显示$f=16$时效果最好。这是因为,在固定Transformer的生成分辨率的前提下,$f$越小,Transformer的感受野越小。如果Transformer的感受野过小,就学习不到足够的信息。

在论文4.4节中,作者探究了VQGAN+Transformer在多项基准测试(benchmark)上的结果。

首先是语义图像合成(根据语义分割图像来生成)任务。本文的这套方法还不错。

接着是人脸生成任务。这套方法表现还行,但还是比不过专精于某一任务的GAN。

作者还比较了各模型在ImageNet上的生成结果。这一比较的数据量较多,欢迎大家自行阅读原论文。

在论文4.2节中,作者展示了多才多艺的VQGAN+Transformer在各种约束下的图像生成结果。这些图像都是按照默认配置生成的,大小为$256\times256$。

作者还展示了使用了滑动窗口算法后,模型生成的不同分辨率的图像。

本文开头的那张高清图片也来自论文。

总结

VQGAN是一个改进版的VQVAE,它将感知误差和GAN引入了图像压缩模型,把压缩图像生成模型替换成了更强大的Transformer。相比纯种的GAN(如StyleGAN),VQGAN的强大之处在于它支持带约束的高清图像生成。VQGAN借助NLP中”decoder-only”策略实现了带约束图像生成,并使用滑动窗口机制实现了高清图像生成。虽然在某些特定任务上VQGAN还是落后于其他GAN,但VQGAN的泛化性和灵活性都要比纯种GAN要强。它的这些潜力直接促成了Stable Diffusion的诞生。

如果你是读完了VQVAE再来读的VQGAN,为了完全理解VQGAN,你只需要掌握本文提到的4个知识点:VQVAE到VQGAN的改进方法、使用Transformer做图像生成的方法、使用”decoder-only”策略做带约束图像生成的方法、用滑动滑动窗口生成任意尺寸的图片的思想。

代码阅读

在代码阅读章节中,我将先简略介绍官方源码的项目结构以方便大家学习,再介绍代码中的几处核心代码。具体来说,我会介绍模型是如何组织配置文件的、模型的定义代码在哪、训练代码在哪、采样代码在哪,同时我会主要分析VQGAN的结构、Transformer的结构、损失函数、滑动窗口采样算法这几部分的代码。

官方源码地址:https://github.com/CompVis/taming-transformers。

官方的Git仓库里有很多很大的图片,且git记录里还藏了一些很大的数据,整个Git仓库非常大。如果你的网络不好,建议以zip形式下载仓库,或者只把代码部分下载下来。

项目结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
├─assets
├─configs
├─scripts
└─taming
├─data
│ └─conditional_builder
├─models
└─modules
├─diffusionmodules
├─discriminator
├─losses
├─misc
├─transformer
└─vqvae

configs目录下存放的是模型配置文件。VQGAN和Transformer的模型配置是分开来放的。每个模型配置文件都会指向一个Python模型类,比如taming.models.vqgan.VQModel,配置里的参数就是模型类的初始化参数。我们可用通过阅读配置文件找到模型的定义位置。

运行脚本包括根目录下的main.pyscripts文件夹下的脚本。main.py是用于训练的。scripts文件夹下有各种采样脚本和数据集可视化脚本。

taming是源代码的主目录。其data子文件夹下放置了各数据集的预处理代码,models放置了VQGAN和Transformer PyTorch模型的定义代码,modules则放置了模型中用到的模块,主要包括VQGAN编码解码模块(diffusionmodules)、判别器模块(discriminator)、误差模块(losses)、Transformer模块(transformer)、codebook模块(vqvae)。

VQGAN 模型结构

打开configs\faceshq_vqgan.yaml,我们能够找到高清人脸生成任务使用的VQGAN模型配置。我们来学习一下这个模型的定义方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
model:
base_learning_rate: 4.5e-6
target: taming.models.vqgan.VQModel
params:
embed_dim: 256
n_embed: 1024
ddconfig:
...

lossconfig:
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
params:
...

从配置文件的target字段中,我们知道VQGAN定义在模块taming.models.vqgan.VQModel中。我们可以打开taming\models\vqgan.py这个文件,查看其中VQModel类的代码。

首先先看一下初始化函数。初始化函数主要是初始化了encoderdecoderlossquantize这几个模块,我们可以从文件开头的import语句中找到这几个模块的定义位置。不过,先不急,我们来继续看一下模型的前向传播函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import GumbelQuantize
from taming.modules.vqvae.quantize import EMAVectorQuantizer

class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap, sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.image_key = image_key
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor

模型的前向传播逻辑非常清晰。self.encoder可以把一张图片变为特征,self.decoder可以把特征变回图片。self.quant_convpost_quant_conv则分别完成了编码器到codebook、codebook到解码器的通道数转换。self.quantize实现了VQVAE和VQGAN中那个找codebook里的最近邻、替换成最近邻的操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info

def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec

def forward(self, input):
quant, diff, _ = self.encode(input)
dec = self.decode(quant)
return dec, diff

接下来,我们再看一看VQGAN的各个模块的定义。编码器和解码器的定义都可以在taming\modules\diffusionmodules\model.py里找到。VQGAN使用的编码器和解码器基于DDPM论文中的U-Net架构(而此架构又可以追溯到PixelCNN++的模型架构)。相比于最经典的U-Net,此U-Net每一层由若干个残差块和若干个自注意力块构成。为了把这个U-Net用到VQGAN里,U-Net的下采样部分和上采样部分被拆开,分别做成了VQGAN的编码器和解码器。

此处代码过长,我就只贴出部分关键代码了。以下是编码器的__init__函数和forward函数的关键代码。self.down存储了U-Net各层的模块。对于第i层,down[i].block是所有残差块,down[i].attn是所有自注意力块,down[i].downsample是下采样操作。它们在forward里会被依次调用。解码器的结构与之类似,只不过下采样变成了上采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, **ignore_kwargs):
super().__init__()
...
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)

...


def forward(self, x):
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
...

return h

之后,我们再看看离散化层的代码,即把编码器的输出变成codebook里的嵌入的实现代码。作者在taming\modules\vqvae\quantize.py中提供了VQVAE原版的离散化操作以及若干个改进过的离散化操作。我们就来看一下原版的离散化模块VectorQuantizer是怎么实现的。

离散化模块的初始化非常简洁,主要是初始化了一个嵌入层。

1
2
3
4
5
6
7
8
9
10
class VectorQuantizer(nn.Module):
def __init__(self, n_e, e_dim, beta):
super(VectorQuantizer, self).__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta

self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

在前向传播时,作者先是算出了编码器输出z和所有嵌入的距离d,再用argmin算出了最近邻嵌入的下标min_encodings,最后根据下标取出解码器输入z_q。同时,该函数还计算了其他几个可能用到的量,比如和codebook有关的误差 loss。注意,在计算lossz_q时,作者都使用到了停止梯度算子(.detach())。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def forward(self, z):
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
torch.matmul(z_flattened, self.embedding.weight.t())

## could possible replace this here
# #\start...
# find closest encodings
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)

min_encodings = torch.zeros(
min_encoding_indices.shape[0], self.n_e).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)

z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
#.........\end


# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
torch.mean((z_q - z.detach()) ** 2)

# preserve gradients
z_q = z + (z_q - z).detach()

# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()

return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

VQGAN的三个主要模块已经看完了。最后,我们来看一下误差的定义。误差的定义在taming\modules\losses\vqperceptual.pyVQLPIPSWithDiscriminator类里。误差类名里的LPIPS(Learned Perceptual Image Patch Similarity,学习感知图像块相似度)就是感知误差的全称,”WithDiscriminator”表示误差是带了判定器误差的。我们来把这两类误差分别看一下。

说实话,这个误差模块乱得一塌糊涂,一边自己在算误差,一边又维护了codebook误差和重建误差的权重,最后会把自己维护的两个误差和其他误差合在一起输出。功能全部耦合在一起。我们就跳过这个类的实现细节,主要关注self.perceptual_lossself.discriminator是怎么调用其他模块的。

1
2
3
4
5
6
7
8
9
10
from taming.modules.losses.lpips import LPIPS
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

class VQLPIPSWithDiscriminator(nn.Module):
def __init__(self, ...):
super().__init__()

self.perceptual_loss = LPIPS().eval()

self.discriminator = NLayerDiscriminator...

感知误差模块在taming\modules\losses\vqperceptual.py文件里。这个文件来自GitHub项目 PerceptualSimilarity。

感知误差可以简单地理解为两张图片在VGG中几个卷积层输出的误差的加权和。加权的权重是可以学习的。作者使用的是已经学习好的感知误差。感知误差的初始化函数如下。其中,self.lin0等模块就是算权重的模块,self.net是VGG。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class LPIPS(nn.Module):
# Learned perceptual metric
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512] # vg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False

在算误差时,先是把图像inputtarget都输入进VGG,获取各层输出outs0, outs1,再求出两个图像的输出的均方误差diffs,最后用lins给各层误差加权,求和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val

GAN的判别器写在taming\modules\discriminator\model.py文件里。这个文件来自GitHub上的 pytorch-CycleGAN-and-pix2pix 项目。这个判别器非常简单,就是一个全卷积网络。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator as in Pix2Pix
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminator, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d

kw = 4
padw = 1
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]

nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]

sequence += [
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)

def forward(self, input):
"""Standard forward."""
return self.main(input)

Transformer 模型结构

此方法使用的Transformer是GPT2。我们先看一下该项目封装Transformer的模型类taming.models.cond_transformer.Net2NetTransformer,再稍微看一下GPT类taming.modules.transformer.mingpt.GPT的具体实现。

Net2NetTransformer主要是实现了论文中提到的带约束生成。它会把输入x和约束c分别用一个VQGAN转成压缩图像,把图像压扁成一维,再调用GPT。我们来看一下这个类的主要内容。

初始化函数主要是初始化了输入图像的VQGAN self.first_stage_model、约束图像的VQGAN self.cond_stage_model、Transformer self.transformer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Net2NetTransformer(pl.LightningModule):
def __init__(self,
transformer_config,
first_stage_config,
cond_stage_config,
permuter_config=None,
ckpt_path=None,
ignore_keys=[],
first_stage_key="image",
cond_stage_key="depth",
downsample_cond_size=-1,
pkeep=1.0,
sos_token=0,
unconditional=False,
):
super().__init__()
self.be_unconditional = unconditional
self.sos_token = sos_token
self.first_stage_key = first_stage_key
self.cond_stage_key = cond_stage_key
self.init_first_stage_from_ckpt(first_stage_config)
self.init_cond_stage_from_ckpt(cond_stage_config)
if permuter_config is None:
permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
self.permuter = instantiate_from_config(config=permuter_config)
self.transformer = instantiate_from_config(config=transformer_config)

if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.downsample_cond_size = downsample_cond_size
self.pkeep = pkeep

def init_first_stage_from_ckpt(self, config):
model = instantiate_from_config(config)
model = model.eval()
model.train = disabled_train
self.first_stage_model = model

def init_cond_stage_from_ckpt(self, config):
...
self.cond_stage_model = ...

模型的前向传播函数如下。一开始,函数调用encode_to_zencode_to_c,根据self.cond_stage_modelself.first_stage_model把约束图像和输入图像编码成压扁至一维的压缩图像。之后函数做了一个类似Dropout的操作,根据self.pkeep随机替换掉约束编码。最后,函数把约束编码和输入编码拼接起来,使用通常方法调用Transformer。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def forward(self, x, c):
# one step to produce the logits
_, z_indices = self.encode_to_z(x)
_, c_indices = self.encode_to_c(c)

if self.training and self.pkeep < 1.0:
mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
device=z_indices.device))
mask = mask.round().to(dtype=torch.int64)
r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
a_indices = mask*z_indices+(1-mask)*r_indices
else:
a_indices = z_indices

cz_indices = torch.cat((c_indices, a_indices), dim=1)

# target includes all sequence elements (no need to handle first one
# differently because we are conditioning)
target = z_indices
# make the prediction
logits, _ = self.transformer(cz_indices[:, :-1])
# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
logits = logits[:, c_indices.shape[1]-1:]

return logits, target

GPT2的结构不是本文的重点,我们就快速把模型结构过一遍了。GPT2的模型定义在taming.modules.transformer.mingpt.GPT里。GPT2的结构并不复杂,就是一个只有解码器的Transformer。前向传播时,数据先通过嵌入层self.tok_emb,再经过若干个Transformer模块self.blocks,最后过一个LayerNorm层self.ln_f和线性层self.head

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class GPT(nn.Module):

def forward(self, idx, embeddings=None, targets=None):
# forward the GPT model
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector

if embeddings is not None: # prepend explicit embeddings
token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

t = token_embeddings.shape[1]
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(token_embeddings + position_embeddings)
x = self.blocks(x)
x = self.ln_f(x)
logits = self.head(x)

# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

return logits, loss

每个Transformer块就是非常经典的自注意力加全连接层。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Block(nn.Module):
""" an unassuming Transformer block """
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(), # nice
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.resid_pdrop),
)

def forward(self, x, layer_past=None, return_present=False):
# TODO: check that training still works
if return_present: assert not self.training
# layer past: tuple of length two with B, nh, T, hs
attn, present = self.attn(self.ln1(x), layer_past=layer_past)

x = x + attn
x = x + self.mlp(self.ln2(x))
if layer_past is not None or return_present:
return x, present
return x

基于滑动窗口的带约束图像生成

看完了所有模型的结构,我们最后来学习一下论文中没能详细介绍的滑动窗口算法。在scripts\taming-transformers.ipynb里有一个采样算法的最简实现,我们就来学习一下这份代码。

这份代码可以根据一幅语义分割图像来生成高清图像。一开始,代码会读入模型和语义分割图像。大致的代码为:

1
2
3
4
5
6
7
from taming.models.cond_transformer import Net2NetTransformer
model = Net2NetTransformer(**config.model.params)
from PIL import Image
import numpy as np
segmentation_path = "data/sflckr_segmentations/norway/25735082181_999927fe5a_b.png"
segmentation = Image.open(segmentation_path)
...


之后,代码把约束图像用对应的VQGAN编码进压缩空间,得到c_indices。由于待生成图像为空,我们可以随便生成一个待生成图像的压缩图像z_indices,代码中使用了randint初始化待生成的压缩图像。

1
2
3
4
5
6
7
8
c_code, c_indices = model.encode_to_c(segmentation)
z_indices = torch.randint(codebook_size, z_indices_shape, device=model.device)

idx = z_indices
idx = idx.reshape(z_code_shape[0],z_code_shape[2],z_code_shape[3])

cidx = c_indices
cidx = cidx.reshape(c_code.shape[0],c_code.shape[2],c_code.shape[3])

最后就是最关键的滑动窗口采样部分了。我们先稍微浏览一遍代码,再详细地一行一行看过去。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
temperature = 1.0
top_k = 100

for i in range(0, z_code_shape[2]-0):
if i <= 8:
local_i = i
elif z_code_shape[2]-i < 8:
local_i = 16-(z_code_shape[2]-i)
else:
local_i = 8
for j in range(0,z_code_shape[3]-0):
if j <= 8:
local_j = j
elif z_code_shape[3]-j < 8:
local_j = 16-(z_code_shape[3]-j)
else:
local_j = 8

i_start = i-local_i
i_end = i_start+16
j_start = j-local_j
j_end = j_start+16

patch = idx[:,i_start:i_end,j_start:j_end]
patch = patch.reshape(patch.shape[0],-1)
cpatch = cidx[:, i_start:i_end, j_start:j_end]
cpatch = cpatch.reshape(cpatch.shape[0], -1)
patch = torch.cat((cpatch, patch), dim=1)
logits,_ = model.transformer(patch[:,:-1])
logits = logits[:, -256:, :]
logits = logits.reshape(z_code_shape[0],16,16,-1)
logits = logits[:,local_i,local_j,:]

logits = logits/temperature

if top_k is not None:
logits = model.top_k_logits(logits, top_k)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx[:,i,j] = torch.multinomial(probs, num_samples=1)

x_sample = model.decode_to_img(idx, z_code_shape)
show_image(x_sample)

一开始的temperaturetop_k是得到logit后的采样参数,和滑动窗口算法无关。
1
2
temperature = 1.0
top_k = 100

进入生成图像循环后,i, j分别表示压缩图像的竖索引和横索引,i_start, i_end, j_start, j_end是滑动窗口上下左右边界。

1
2
3
4
5
6
7
8
for i in range(0, z_code_shape[2]-0):
...
for j in range(0,z_code_shape[3]-0):
...
i_start = i-local_i
i_end = i_start+16
j_start = j-local_j
j_end = j_start+16

为了获取这四个滑动窗口的范围,代码用了若干条件语句计算待生成像素在滑动窗口里的相对位置local_i, local_j

1
2
3
4
5
6
7
8
9
10
11
12
13
14
for i in range(0, z_code_shape[2]-0):
if i <= 8:
local_i = i
elif z_code_shape[2]-i < 8:
local_i = 16-(z_code_shape[2]-i)
else:
local_i = 8
for j in range(0,z_code_shape[3]-0):
if j <= 8:
local_j = j
elif z_code_shape[3]-j < 8:
local_j = 16-(z_code_shape[3]-j)
else:
local_j = 8

得到了滑动窗口的边界后,代码用滑动窗口从约束图像的压缩图像和待生成图像的压缩图像上各取出一个图块,并拼接起来。

1
2
3
4
5
patch = idx[:,i_start:i_end,j_start:j_end]
patch = patch.reshape(patch.shape[0],-1)
cpatch = cidx[:, i_start:i_end, j_start:j_end]
cpatch = cpatch.reshape(cpatch.shape[0], -1)
patch = torch.cat((cpatch, patch), dim=1)

之后,只需要把拼接的图块直接输入进Transformer,得到输出logits,再用local_i,local_j去输出图块的对应位置取出下一个压缩图像像素的概率分布,就可以随机生成下一个压缩图像像素了。如前文所述,Transformer类会把二维的图块压扁到一维,输入进GPT。同时,GPT会自动保证前面的像素看不到后面的像素,我们不需要人为地指定约束像素。这个地方的调用逻辑其实非常简单。

1
2
3
4
logits,_ = model.transformer(patch[:,:-1])
logits = logits[:, -256:, :]
logits = logits.reshape(z_code_shape[0],16,16,-1)
logits = logits[:,local_i,local_j,:]

最后只要从logits里采样,把采样出的压缩图像像素填入idx,就完成了一步生成。

1
2
3
4
5
6
7
logits = logits/temperature

if top_k is not None:
logits = model.top_k_logits(logits, top_k)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx[:,i,j] = torch.multinomial(probs, num_samples=1)

反复执行循环,就能将压缩图像生成完毕。最后将压缩图像过一遍VQGAN的解码器即可得到最终的生成图像。

1
2
x_sample = model.decode_to_img(idx, z_code_shape)
show_image(x_sample)

参考资料

VQGAN论文:https://arxiv.org/abs/2012.09841

VQGAN GitHub:https://github.com/CompVis/taming-transformers

如果你需要补充学习早期工作,欢迎阅读我之前的文章。

Transformer解读

PixelCNN解读

VQVAE解读

在这篇文章中,我将详细地介绍一个英中翻译 Transformer 的 PyTorch 实现。这篇文章会完整地展示一个深度学习项目的搭建过程,从数据集准备,到模型定义、训练。这篇文章不仅会讲解如何把 Transformer 的论文翻译成代码,还会讲清楚代码实现中的诸多细节,并分享我做实验时碰到的种种坑点。相信初学者能够从这篇文章中学到丰富的知识。

项目网址: https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/Transformer

如果你对 Transformer 的论文不熟,欢迎阅读我之前的文章:Attention Is All You Need (Transformer) 论文精读

数据集准备

我在 https://github.com/P3n9W31/transformer-pytorch 项目中找到了一个较小的中英翻译数据集。数据集只有几KB大小,中英词表只有10000左右,比较适合做Demo。如果要实现更加强大实用的模型,则需要换更大的数据集。但相应地,你要多花费更多的时间来训练。

我在代码仓库中提供了data_load.py文件。执行这个文件后,实验所需要的数据会自动下载到项目目录的data文件夹下。

该数据集由cn.txt, en.txt, cn.txt.vocab.tsv, en.txt.vocab.tsv这四个文件组成。前两个文件包含相互对应的中英文句子,其中中文已做好分词,英文全为小写且标点已被分割好。后两个文件是预处理好的词表。语料来自2000年左右的中国新闻,其第一条的中文及其翻译如下:

1
2
目前 粮食 出现 阶段性 过剩 , 恰好 可以 以 粮食 换 森林 、 换 草地 , 再造 西部 秀美 山川 。
the present food surplus can specifically serve the purpose of helping western china restore its woodlands , grasslands , and the beauty of its landscapes .

词表则统计了各个单词的出现频率。通过使用词表,我们能实现单词和序号的相互转换(比如中文里的5号对应“的”字,英文里的5号对应”the”)。词表的前四个单词是特殊字符,分别为填充字符、频率太少没有被加入词典的词语、句子开始字符、句子结束字符。

1
2
3
4
5
6
7
8
<PAD>	1000000000
<UNK> 1000000000
<S> 1000000000
</S> 1000000000
的 8461
是 2047
和 1836
在 1784
1
2
3
4
5
6
7
8
<PAD>	1000000000
<UNK> 1000000000
<S> 1000000000
</S> 1000000000
the 13680
and 6845
of 6259
to 4292

只要运行一遍data_load.py下好数据后,我们待会就能用load_train_data()来获取已经打成batch的训练数据,并用API获取cn2idx, idx2cn, en2idx, idx2en这四个描述中英文序号与单词转换的词典。我们会在之后的训练代码里见到它们的用法。

Transformer 模型

准备好数据后,接下来就要进入这个项目最重要的部分——Transformer 模型实现了。我将按照代码的执行顺序,从前往后,自底向上地介绍 Transformer 的各个模块:Positional Encoding, MultiHeadAttention, Encoder & Decoder, 最后介绍如何把各个模块拼到一起。在这个过程中,我还会着重介绍一个论文里没有提及,但是代码实现时非常重要的一个细节——<pad>字符的处理。

说实话,用 PyTorch 实现 Transformer 没有什么有变数的地方,大家的代码写得都差不多,我也是参考着别人的教程写的。但是,Transformer 的代码实现中有很多坑。大部分人只会云淡风轻地介绍一下最终的代码成品,不会去讲他们 debug 耗费了多少时间,哪些地方容易出错。而我会着重讲一下代码中的一些细节,以及我碰到过的问题。

Positional Encoding

模型一开始是一个 Embedding 层加一个 Positional Encoding。Embedding 在 PyTorch 里已经有实现了,且不是文章的创新点,我们就直接来看 Positional Encoding 的写法。

求 Positional Encoding,其实就是求一个二元函数的许多函数值构成的矩阵。对于二元函数$PE(pos, i)$,我们要求出$pos \in [0, seqlen - 1], i \in [0, d_{model} - 1]$时所有的函数值,其中,$seqlen$是该序列的长度,$d_{model}$是每一个词向量的长度。

理论上来说,每个句子的序列长度$seqlen$是不固定的。但是,我们可以提前预处理一个$seqlen$很大的 Positional Encoding 矩阵 。每次有句子输入进来,根据这个句子的序列长度,去预处理好的矩阵里取一小段出来即可。

这样,整个类的实现应该如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class PositionalEncoding(nn.Module):

def __init__(self, d_model: int, max_seq_len: int):
super().__init__()

# Assume d_model is an even number for convenience
assert d_model % 2 == 0

i_seq = torch.linspace(0, max_seq_len - 1, max_seq_len)
j_seq = torch.linspace(0, d_model - 2, d_model // 2)
pos, two_i = torch.meshgrid(i_seq, j_seq)
pe_2i = torch.sin(pos / 10000**(two_i / d_model))
pe_2i_1 = torch.cos(pos / 10000**(two_i / d_model))
pe = torch.stack((pe_2i, pe_2i_1), 2).reshape(1, max_seq_len, d_model)

self.register_buffer('pe', pe, False)

def forward(self, x: torch.Tensor):
n, seq_len, d_model = x.shape
pe: torch.Tensor = self.pe
assert seq_len <= pe.shape[1]
assert d_model == pe.shape[2]
rescaled_x = x * d_model**0.5
return rescaled_x + pe[:, 0:seq_len, :]

代码中有不少需要讲解的部分。首先,先看一下预处理好的矩阵pe是怎么在__init__中算出来的。pe可以很直接地用两层循环算出来。由于这段预处理代码只会执行一次,相对于冗长的训练时间,哪怕生成pe的代码性能差一点也没事。然而,作为一个编程高手,我准备秀一下如何用并行的方法求出pe

为了并行地求pe,我们要初始化一个二维网格,表示自变量$pos, i$。生成网格可以用下面的代码实现。(由于$i$要分奇偶讨论,$i$的个数是$\frac{d_{model}}{2}$)

1
2
3
i_seq = torch.linspace(0, max_seq_len - 1, max_seq_len)
j_seq = torch.linspace(0, d_model - 2, d_model // 2)
pos, two_i = torch.meshgrid(i_seq, j_seq)

torch.meshgrid用于生成网格。比如torch.meshgrid([0, 1], [0, 1])就可以生成[[(0, 0), (0, 1)], [(1, 0), (1, 1)]]这四个坐标构成的网格。不过,这个函数会把坐标的两个分量分别返回。比如:

1
2
3
i, j = torch.meshgrid([0, 1], [0, 1])
# i: [[0, 0], [1, 1]]
# j: [[0, 1], [0, 1]]

利用这个函数的返回结果,我们可以把pos, two_i套入论文的公式,并行地分别算出奇偶位置的 PE 值。

1
2
pe_2i = torch.sin(pos / 10000**(two_i / d_model))
pe_2i_1 = torch.cos(pos / 10000**(two_i / d_model))

有了奇偶处的值,现在的问题是怎么把它们优雅地拼到同一个维度上。我这里先把它们堆成了形状为seq_len, d_model/2, 2的一个张量,再把最后一维展平,就得到了最后的pe矩阵。这一操作等于新建一个seq_len, d_model形状的张量,再把奇偶位置处的值分别填入。

1
pe = torch.stack((pe_2i, pe_2i_1), 2).reshape(1, max_seq_len, d_model)

最后,要注意一点。只用 self.pe = pe 记录这个量是不够好的。我们最好用 self.register_buffer('pe', pe, False) 把这个量登记成 torch.nn.Module 的一个存储区(这一步会自动完成self.pe = pe)。这里涉及到 PyTorch 的一些知识了。

PyTorch 的 Module 会记录两类参数,一类是 parameter 可学习参数,另一类是 buffer 不可学习的参数。把变量登记成 buffer 的最大好处是,在使用 model.to(device) 把一个模型搬到另一个设备上时,所有 parameterbuffer 都会自动被搬过去。另外,bufferparameter 一样,也可以被记录到 state_dict 中,并保存到文件里。register_buffer 的第三个参数决定了是否将变量加入 state_dict。由于 pe 可以直接计算,不需要记录,可以把这个参数设成 False

预处理好 pe 后,用起来就很方便了。每次读取输入的序列长度,从中取一段出来即可。

另外,Transformer 给嵌入层乘了个系数$\sqrt{d_{model}}$。为了方便起见,我把这个系数放到了 PositionalEncoding 类里面。

1
2
3
4
5
6
7
def forward(self, x: torch.Tensor):
n, seq_len, d_model = x.shape
pe: torch.Tensor = self.pe
assert seq_len <= pe.shape[1]
assert d_model == pe.shape[2]
rescaled_x = x * d_model**0.5
return rescaled_x + pe[:, 0:seq_len, :]

Scaled Dot-Product Attention

下一步是多头注意力层。为了实现多头注意力,我们先要实现 Transformer 里经典的注意力计算。而在讲注意力计算之前,我还要补充一下 Transformer 中有关 mask 的一些知识。

Transformer 里的 mask

Transformer 最大的特点就是能够并行训练。给定翻译好的第1~n个词语,它默认会并行地预测第2~(n+1)个下一个词语。为了模拟串行输出的情况,第$t$个词语不应该看到第$t+1$个词语之后的信息。

输入信息 输出
(y1, —, —, —) y2
(y1, y2, —, —) y3
(y1, y2, y3, —) y4
(y1, y2, y3, y4) y5

为了实现这一功能,Transformer 在Decoder里使用了掩码。掩码取1表示这个地方的数是有效的,取0表示这个地方的数是无效的。Decoder 里的这种掩码应该是一个上三角全1矩阵。

掩码是在注意力计算中生效的。对于掩码取0的区域,其softmax前的$QK^T$值取负无穷。这是因为,对于softmax

令$x_i=-\infty$可以让它在 softmax 的分母里不产生任何贡献。

以上是论文里提到的 mask,它用来模拟 Decoder 的串行推理。而在代码实现中,还有其他地方会产生 mask。在生成一个 batch 的数据时,要给句子填充 <pad>。这个特殊字符也没有实际意义,不应该对计算产生任何贡献。因此,有 <pad> 的地方的 mask 也应该为0。之后讲 Transformer 模型类时,我会介绍所有的 mask 该怎么生成,这里我们仅关注注意力计算是怎么用到 mask 的。

注意力计算

补充完了背景知识,我们来看注意力计算的实现代码。由于注意力计算没有任何的状态,因此它应该写成一个函数,而不是一个类。我们可以轻松地用 PyTorch 代码翻译注意力计算的公式。(注意,我这里的 mask 表示哪些地方要填负无穷,而不是像之前讲的表示哪些地方有效)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
MY_INF = 1e12

def attention(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: Optional[torch.Tensor] = None):
'''
Note: The dtype of mask must be bool
'''
# q shape: [n, heads, q_len, d_k]
# k shape: [n, heads, k_len, d_k]
# v shape: [n, heads, k_len, d_v]
assert q.shape[-1] == k.shape[-1]
d_k = k.shape[-1]
# tmp shape: [n, heads, q_len, k_len]
tmp = torch.matmul(q, k.transpose(-2, -1)) / d_k**0.5
if mask is not None:
tmp.masked_fill_(mask, -MY_INF)
tmp = F.softmax(tmp, -1)
# tmp shape: [n, heads, q_len, d_v]
tmp = torch.matmul(tmp, v)
return tmp

这里有一个很坑的地方。引入了 <pad> 带来的 mask 后,会产生一个新的问题:可能一整行数据都是失效的,softmax 用到的所有 $x_i$ 可能都是负无穷。

这个数是没有意义的。如果用torch.inf来表示无穷大,就会令exp(torch.inf)=0,最后 softmax 结果会出现 NaN,代码大概率是跑不通的。

但是,大多数 PyTorch Transformer 教程压根就没提这一点,而他们的代码又还是能够跑通。拿放大镜仔细对比了代码后,我发现,他们的无穷大用的不是 torch.inf,而是自己随手设的一个极大值。这样,exp(-MY_INF)得到的不再是0,而是一个极小值。softmax 的结果就会等于分母的项数,而不是 NaN,不会有数值计算上的错误。

Multi-Head Attention

有了注意力计算,就可以实现多头注意力层了。多头注意力层是有学习参数的,它应该写成一个类。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class MultiHeadAttention(nn.Module):

def __init__(self, heads: int, d_model: int, dropout: float = 0.1):
super().__init__()

assert d_model % heads == 0
# dk == dv
self.d_k = d_model // heads
self.heads = heads
self.d_model = d_model
self.q = nn.Linear(d_model, d_model)
self.k = nn.Linear(d_model, d_model)
self.v = nn.Linear(d_model, d_model)
self.out = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)

def forward(self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: Optional[torch.Tensor] = None):
# batch should be same
assert q.shape[0] == k.shape[0]
assert q.shape[0] == v.shape[0]
# the sequence length of k and v should be aligned
assert k.shape[1] == v.shape[1]

n, q_len = q.shape[0:2]
n, k_len = k.shape[0:2]
q_ = self.q(q).reshape(n, q_len, self.heads, self.d_k).transpose(1, 2)
k_ = self.k(k).reshape(n, k_len, self.heads, self.d_k).transpose(1, 2)
v_ = self.v(v).reshape(n, k_len, self.heads, self.d_k).transpose(1, 2)

attention_res = attention(q_, k_, v_, mask)
concat_res = attention_res.transpose(1, 2).reshape(
n, q_len, self.d_model)
concat_res = self.dropout(concat_res)

output = self.out(concat_res)
return output

这段代码一处很灵性的地方。在 Transformer 的论文中,多头注意力是先把每个词的表示拆成$h$个头,再对每份做投影、注意力,最后拼接起来,再投影一次。其实,拆开与拼接操作是多余的。我们可以通过一些形状上的操作,等价地实现拆开与拼接,以提高运行效率。

具体来说,我们可以一开始就让所有头的数据经过同一个线性层。之后在做注意力之前把头和序列数这两维转置一下。这两步操作和拆开来做投影、注意力是等价的。做完了注意力操作之后,再把两个维度转置回来,这和拼接操作是等价的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def forward(self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: Optional[torch.Tensor] = None):
# batch should be same
assert q.shape[0] == k.shape[0]
assert q.shape[0] == v.shape[0]
# the sequence length of k and v should be aligned
assert k.shape[1] == v.shape[1]

n, q_len = q.shape[0:2]
n, k_len = k.shape[0:2]
q_ = self.q(q).reshape(n, q_len, self.heads, self.d_k).transpose(1, 2)
k_ = self.k(k).reshape(n, k_len, self.heads, self.d_k).transpose(1, 2)
v_ = self.v(v).reshape(n, k_len, self.heads, self.d_k).transpose(1, 2)

attention_res = attention(q_, k_, v_, mask)
concat_res = attention_res.transpose(1, 2).reshape(
n, q_len, self.d_model)
concat_res = self.dropout(concat_res)

output = self.out(concat_res)
return output

前馈网络

前馈网络太简单了,两个线性层,没什么好说的。注意内部那个隐藏层的维度大小$d_{ff}$会比$d_{model}$更大一点。

1
2
3
4
5
6
7
8
9
10
11
12
13
class FeedForward(nn.Module):

def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.layer1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.layer2 = nn.Linear(d_ff, d_model)

def forward(self, x):
x = self.layer1(x)
x = self.dropout(F.relu(x))
x = self.layer2(x)
return x

Encoder & Decoder

准备好一切组件后,就可以把模型一层一层搭起来了。先搭好每个 Encoder 层和 Decoder 层,再拼成 Encoder 和 Decoder。

Encoder 层和 Decoder 层的结构与论文中的描述一致,且每个子层后面都有一个 dropout,和上一层之间使用了残差连接。归一化的方法是 LayerNorm。顺带一提,不仅是这些层,前面很多子层的计算中都加入了 dropout。

再提一句 mask。由于 encoder 和 decoder 的输入不同,它们的填充情况不同,产生的 mask 也不同。后文会展示这些 mask 的生成方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class EncoderLayer(nn.Module):

def __init__(self,
heads: int,
d_model: int,
d_ff: int,
dropout: float = 0.1):
super().__init__()
self.self_attention = MultiHeadAttention(heads, d_model, dropout)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)

def forward(self, x, src_mask: Optional[torch.Tensor] = None):
tmp = self.self_attention(x, x, x, src_mask)
tmp = self.dropout1(tmp)
x = self.norm1(x + tmp)
tmp = self.ffn(x)
tmp = self.dropout2(tmp)
x = self.norm2(x + tmp)
return x


class DecoderLayer(nn.Module):

def __init__(self,
heads: int,
d_model: int,
d_ff: int,
dropout: float = 0.1):
super().__init__()
self.self_attention = MultiHeadAttention(heads, d_model, dropout)
self.attention = MultiHeadAttention(heads, d_model, dropout)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)

def forward(self,
x,
encoder_kv: torch.Tensor,
dst_mask: Optional[torch.Tensor] = None,
src_dst_mask: Optional[torch.Tensor] = None):
tmp = self.self_attention(x, x, x, dst_mask)
tmp = self.dropout1(tmp)
x = self.norm1(x + tmp)
tmp = self.attention(x, encoder_kv, encoder_kv, src_dst_mask)
tmp = self.dropout2(tmp)
x = self.norm2(x + tmp)
tmp = self.ffn(x)
tmp = self.dropout3(tmp)
x = self.norm3(x + tmp)
return x

Encoder 和 Decoder 就是在所有子层前面加了一个嵌入层、一个位置编码,再把多个子层堆起来了而已,其他输入输出照搬即可。注意,我们可以给嵌入层输入pad_idx参数,让<pad>的计算不对梯度产生贡献。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class Encoder(nn.Module):

def __init__(self,
vocab_size: int,
pad_idx: int,
d_model: int,
d_ff: int,
n_layers: int,
heads: int,
dropout: float = 0.1,
max_seq_len: int = 120):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model, pad_idx)
self.pe = PositionalEncoding(d_model, max_seq_len)
self.layers = []
for i in range(n_layers):
self.layers.append(EncoderLayer(heads, d_model, d_ff, dropout))
self.layers = nn.ModuleList(self.layers)
self.dropout = nn.Dropout(dropout)

def forward(self, x, src_mask: Optional[torch.Tensor] = None):
x = self.embedding(x)
x = self.pe(x)
x = self.dropout(x)
for layer in self.layers:
x = layer(x, src_mask)
return x


class Decoder(nn.Module):

def __init__(self,
vocab_size: int,
pad_idx: int,
d_model: int,
d_ff: int,
n_layers: int,
heads: int,
dropout: float = 0.1,
max_seq_len: int = 120):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model, pad_idx)
self.pe = PositionalEncoding(d_model, max_seq_len)
self.layers = []
for i in range(n_layers):
self.layers.append(DecoderLayer(heads, d_model, d_ff, dropout))
self.layers = nn.Sequential(*self.layers)
self.dropout = nn.Dropout(dropout)

def forward(self,
x,
encoder_kv,
dst_mask: Optional[torch.Tensor] = None,
src_dst_mask: Optional[torch.Tensor] = None):
x = self.embedding(x)
x = self.pe(x)
x = self.dropout(x)
for layer in self.layers:
x = layer(x, encoder_kv, dst_mask, src_dst_mask)
return x

Transformer 类

终于,激动人心的时候到来了。我们要把各个子模块组成变形金刚(Transformer)了。先过一遍所有的代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Transformer(nn.Module):

def __init__(self,
src_vocab_size: int,
dst_vocab_size: int,
pad_idx: int,
d_model: int,
d_ff: int,
n_layers: int,
heads: int,
dropout: float = 0.1,
max_seq_len: int = 200):
super().__init__()
self.encoder = Encoder(src_vocab_size, pad_idx, d_model, d_ff,
n_layers, heads, dropout, max_seq_len)
self.decoder = Decoder(dst_vocab_size, pad_idx, d_model, d_ff,
n_layers, heads, dropout, max_seq_len)
self.pad_idx = pad_idx
self.output_layer = nn.Linear(d_model, dst_vocab_size)

def generate_mask(self,
q_pad: torch.Tensor,
k_pad: torch.Tensor,
with_left_mask: bool = False):
# q_pad shape: [n, q_len]
# k_pad shape: [n, k_len]
# q_pad k_pad dtype: bool
assert q_pad.device == k_pad.device
n, q_len = q_pad.shape
n, k_len = k_pad.shape

mask_shape = (n, 1, q_len, k_len)
if with_left_mask:
mask = 1 - torch.tril(torch.ones(mask_shape))
else:
mask = torch.zeros(mask_shape)
mask = mask.to(q_pad.device)
for i in range(n):
mask[i, :, q_pad[i], :] = 1
mask[i, :, :, k_pad[i]] = 1
mask = mask.to(torch.bool)
return mask

def forward(self, x, y):

src_pad_mask = x == self.pad_idx
dst_pad_mask = y == self.pad_idx
src_mask = self.generate_mask(src_pad_mask, src_pad_mask, False)
dst_mask = self.generate_mask(dst_pad_mask, dst_pad_mask, True)
src_dst_mask = self.generate_mask(dst_pad_mask, src_pad_mask, False)
encoder_kv = self.encoder(x, src_mask)
res = self.decoder(y, encoder_kv, dst_mask, src_dst_mask)
res = self.output_layer(res)
return res

我们一点一点来看。先看初始化函数。初始化函数的输入其实就是 Transformer 模型的超参数。总结一下,Transformer 应该有这些超参数:

  • d_model 模型中大多数词向量表示的维度大小
  • d_ff 前馈网络隐藏层维度大小
  • n_layers 堆叠的 Encoder & Decoder 层数
  • head 多头注意力的头数
  • dropout Dropout 的几率

另外,为了构建嵌入层,要知道源语言、目标语言的词典大小,并且提供pad_idx。为了预处理位置编码,需要提前知道一个最大序列长度。

照着子模块的初始化参数表,把参数归纳到__init__的参数表里即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def __init__(self,
src_vocab_size: int,
dst_vocab_size: int,
pad_idx: int,
d_model: int,
d_ff: int,
n_layers: int,
heads: int,
dropout: float = 0.1,
max_seq_len: int = 200):
super().__init__()
self.encoder = Encoder(src_vocab_size, pad_idx, d_model, d_ff,
n_layers, heads, dropout, max_seq_len)
self.decoder = Decoder(dst_vocab_size, pad_idx, d_model, d_ff,
n_layers, heads, dropout, max_seq_len)
self.pad_idx = pad_idx
self.output_layer = nn.Linear(d_model, dst_vocab_size)

再看一下 forward 函数。forward先预处理好了所有的 mask,再逐步执行 Transformer 的计算:先是通过 Encoder 获得源语言的中间表示encoder_kv,再把它和目标语言y的输入一起传入 Decoder,最后经过线性层输出结果res。由于 PyTorch 的交叉熵损失函数自带了 softmax 操作,这里不需要多此一举。

Transformer 论文提到,softmax 前的那个线性层可以和嵌入层共享权重。也就是说,嵌入和输出前的线性层分别完成了词序号到词嵌入的正反映射,两个操作应该是互逆的。但是,词嵌入矩阵不是一个方阵,它根本不能求逆矩阵。我想破头也没想清楚是怎么让线性层可以和嵌入层共享权重的。网上的所有实现都没有对这个细节多加介绍,只是新建了一个线性层。我也照做了。

1
2
3
4
5
6
7
8
9
10
11
def forward(self, x, y):

src_pad_mask = x == self.pad_idx
dst_pad_mask = y == self.pad_idx
src_mask = self.generate_mask(src_pad_mask, src_pad_mask, False)
dst_mask = self.generate_mask(dst_pad_mask, dst_pad_mask, True)
src_dst_mask = self.generate_mask(dst_pad_mask, src_pad_mask, False)
encoder_kv = self.encoder(x, src_mask)
res = self.decoder(y, encoder_kv, dst_mask, src_dst_mask)
res = self.output_layer(res)
return res

等了很久,现在可以来仔细看一看 mask 的生成方法了。回忆一下,表示该字符是否有效的 mask 有两个来源。第一个是论文里提到的,用于模拟串行推理的 mask;另一个是填充操作的空白字符引入的 mask。generate_mask 用于生成这些 mask。

generate_mask 的输入有 query 句子和 key 句子的 pad mask q_pad, k_pad,它们的形状为[n, seq_len]。若某处为 True,则表示这个地方的字符是<pad>。对于自注意力,query 和 key 都是一样的;而在 Decoder 的第二个多头注意力层中,query 来自目标语言,key 来自源语言。with_left_mask 表示是不是要加入 Decoder 里面的模拟串行推理的 mask,它会在掩码自注意力里用到。

1
2
3
4
def generate_mask(self,
q_pad: torch.Tensor,
k_pad: torch.Tensor,
with_left_mask: bool = False):

一开始,先取好维度信息,定好张量的形状。在注意力操作中,softmax 前的那个量的形状是 [n, heads, q_len, k_len],表示每一批每一个头的每一个query对每个key之间的相似度。每一个头的mask是一样的。因此,除heads维可以广播外,mask 的形状应和它一样。

1
mask_shape = (n, 1, q_len, k_len)

再新建一个表示最终 mask 的张量。如果不用 Decoder 的那种 mask,就生成一个全零的张量;否则,生成一个上三角为0,其余地方为1的张量。注意,在我的代码中,mask 为 True 或1就表示这个地方需要填负无穷。

1
2
3
4
if with_left_mask:
mask = 1 - torch.tril(torch.ones(mask_shape))
else:
mask = torch.zeros(mask_shape)

最后,把有 <pad> 的地方也标记一下。从mask的形状[n, 1, q_len, k_len]可以知道,q_pad 表示哪些行是无效的,k_pad 表示哪些列是无效的。如果query句子的第i个字符是<pad>,则应该令mask[:, :, i, :] = 1; 如果key句子的第j个字符是<pad>,则应该令mask[:, :, :, j] = 1

下面的代码利用了PyTorch的取下标机制,直接并行地完成了mask赋值。

1
2
3
for i in range(n):
mask[i, :, q_pad[i], :] = 1
mask[i, :, :, k_pad[i]] = 1

完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def generate_mask(self,
q_pad: torch.Tensor,
k_pad: torch.Tensor,
with_left_mask: bool = False):
# q_pad shape: [n, q_len]
# k_pad shape: [n, k_len]
# q_pad k_pad dtype: bool
assert q_pad.device == k_pad.device
n, q_len = q_pad.shape
n, k_len = k_pad.shape

mask_shape = (n, 1, q_len, k_len)
if with_left_mask:
mask = 1 - torch.tril(torch.ones(mask_shape))
else:
mask = torch.zeros(mask_shape)
mask = mask.to(q_pad.device)
for i in range(n):
mask[i, :, q_pad[i], :] = 1
mask[i, :, :, k_pad[i]] = 1
mask = mask.to(torch.bool)
return mask

看完了mask的生成方法后,我们回到前一步,看看mask会在哪些地方被调用。

在 Transformer 中,有三类多头注意力层,它们的 mask 也不同。Encoder 的多头注意力层的 query 和 key 都来自源语言;Decoder 的第一个多头注意力层的 query 和 key 都来自目标语言;Decoder 的第二个多头注意力层的 query 来自目标语言, key 来自源语言。另外,Decoder 的第一个多头注意力层要加串行推理的那个 mask。按照上述描述生成mask即可。

1
2
3
4
5
6
7
8
9
10
11
def forward(self, x, y):
src_pad_mask = x == self.pad_idx
dst_pad_mask = y == self.pad_idx
src_mask = self.generate_mask(src_pad_mask, src_pad_mask, False)
dst_mask = self.generate_mask(dst_pad_mask, dst_pad_mask, True)
src_dst_mask = self.generate_mask(dst_pad_mask, src_pad_mask, False)

encoder_kv = self.encoder(x, src_mask)
res = self.decoder(y, encoder_kv, dst_mask, src_dst_mask)
res = self.output_layer(res)
return res

到此,Transfomer 模型总算编写完成了。

这里再帮大家排一个坑。PyTorch的官方Transformer中使用了下面的参数初始化方式。但是,实际测试后,不知道为什么,我发现使用这种初始化会让模型训不起来。

1
2
3
4
def init_weights(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)

我去翻了翻PyTorch的Transformer示例,发现官方的示例根本没用到Transformer,而是用子模块nn.TransformerDecoder, nn.TransformerEncoder自己搭了一个新的Transformer。这些子模块其实都有自己的init_weights方法。看来官方都信不过自己的Transformer,这个Transformer类的初始化方法就有问题。

在我们的代码中,我们不必手动对参数初始化。PyTorch对每个线性层默认的参数初始化方式就够好了。

训练

准备好了模型、数据集后,剩下的工作非常惬意,只要随便调用一下就行了。训练的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import torch.nn as nn
import time

from dldemos.Transformer.data_load import (get_batch_indices, load_cn_vocab,
load_en_vocab, load_train_data,
maxlen)
from dldemos.Transformer.model import Transformer

# Config
batch_size = 64
lr = 0.0001
d_model = 512
d_ff = 2048
n_layers = 6
heads = 8
dropout_rate = 0.2
n_epochs = 60
PAD_ID = 0


def main():
device = 'cuda'
cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()
# X: en
# Y: cn
Y, X = load_train_data()

print_interval = 100

model = Transformer(len(en2idx), len(cn2idx), PAD_ID, d_model, d_ff,
n_layers, heads, dropout_rate, maxlen)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr)

citerion = nn.CrossEntropyLoss(ignore_index=PAD_ID)
tic = time.time()
cnter = 0
for epoch in range(n_epochs):
for index, _ in get_batch_indices(len(X), batch_size):
x_batch = torch.LongTensor(X[index]).to(device)
y_batch = torch.LongTensor(Y[index]).to(device)
y_input = y_batch[:, :-1]
y_label = y_batch[:, 1:]
y_hat = model(x_batch, y_input)

y_label_mask = y_label != PAD_ID
preds = torch.argmax(y_hat, -1)
correct = preds == y_label
acc = torch.sum(y_label_mask * correct) / torch.sum(y_label_mask)

n, seq_len = y_label.shape
y_hat = torch.reshape(y_hat, (n * seq_len, -1))
y_label = torch.reshape(y_label, (n * seq_len, ))
loss = citerion(y_hat, y_label)

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
optimizer.step()

if cnter % print_interval == 0:
toc = time.time()
interval = toc - tic
minutes = int(interval // 60)
seconds = int(interval % 60)
print(f'{cnter:08d} {minutes:02d}:{seconds:02d}'
f' loss: {loss.item()} acc: {acc.item()}')
cnter += 1

model_path = 'dldemos/Transformer/model.pth'
torch.save(model.state_dict(), model_path)

print(f'Model saved to {model_path}')


if __name__ == '__main__':
main()

所有的超参数都写在代码开头。在模型结构上,我使用了和原论文一样的超参数。

1
2
3
4
5
6
7
8
9
10
# Config
batch_size = 64
lr = 0.0001
d_model = 512
d_ff = 2048
n_layers = 6
heads = 8
dropout_rate = 0.2
n_epochs = 60
PAD_ID = 0

之后,进入主函数。一开始,我们调用load_data.py提供的API,获取中英文序号到单词的转换词典,并获取已经打包好的训练数据。

1
2
3
4
5
6
7
def main():
device = 'cuda'
cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()
# X: en
# Y: cn
Y, X = load_train_data()

接着,我们用参数初始化好要用到的对象,比如模型、优化器、损失函数。

1
2
3
4
5
6
7
8
9
10
11
print_interval = 100

model = Transformer(len(en2idx), len(cn2idx), PAD_ID, d_model, d_ff,
n_layers, heads, dropout_rate, maxlen)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr)

citerion = nn.CrossEntropyLoss(ignore_index=PAD_ID)
tic = time.time()
cnter = 0

再然后,进入训练循环。我们从X, Y里取出源语言和目标语言的序号数组,输入进模型里。别忘了,Transformer可以并行训练。我们给模型输入目标语言前n-1个单词,用第2到第n个单词作为监督标签。

1
2
3
4
5
6
7
for epoch in range(n_epochs):
for index, _ in get_batch_indices(len(X), batch_size):
x_batch = torch.LongTensor(X[index]).to(device)
y_batch = torch.LongTensor(Y[index]).to(device)
y_input = y_batch[:, :-1]
y_label = y_batch[:, 1:]
y_hat = model(x_batch, y_input)

得到模型的预测y_hat后,我们可以把输出概率分布中概率最大的那个单词作为模型给出的预测单词,算一个单词预测准确率。当然,我们要排除掉<pad>的影响。

1
2
3
4
y_label_mask = y_label != PAD_ID
preds = torch.argmax(y_hat, -1)
correct = preds == y_label
acc = torch.sum(y_label_mask * correct) / torch.sum(y_label_mask)

我们最后算一下loss,并执行梯度下降,训练代码就写完了。为了让训练更稳定,不出现梯度过大的情况,我们可以用torch.nn.utils.clip_grad_norm_(model.parameters(), 1)裁剪梯度。

1
2
3
4
5
6
7
8
9
n, seq_len = y_label.shape
y_hat = torch.reshape(y_hat, (n * seq_len, -1))
y_label = torch.reshape(y_label, (n * seq_len, ))
loss = citerion(y_hat, y_label)

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
optimizer.step()

实验

在本项目的实验中,使用单卡3090,约10分钟就能完成训练。最终的训练准确率可以到达90%以上。

1
00006300 12:12 loss: 0.43494755029678345 acc: 0.9049844145774841

该数据集没有提供测试集(原仓库里的测试集来自训练集,这显然不合理)。且由于词表太小,不太好构建测试集。因此,我没有编写从测试集里生成句子并算BLEU score的代码,而是写了一份翻译给定句子的代码。要编写测试BLUE score的代码,只需要把翻译任意句子的代码改个输入,加一个求BLEU score的函数即可。这份翻译任意句子的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch

from dldemos.Transformer.data_load import (load_cn_vocab, load_en_vocab,
idx_to_sentence, maxlen)
from dldemos.Transformer.model import Transformer

# Config
batch_size = 1
lr = 0.0001
d_model = 512
d_ff = 2048
n_layers = 6
heads = 8
dropout_rate = 0.2
n_epochs = 60

PAD_ID = 0


def main():
device = 'cuda'
cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()

model = Transformer(len(en2idx), len(cn2idx), 0, d_model, d_ff, n_layers,
heads, dropout_rate, maxlen)
model.to(device)
model.eval()

model_path = 'dldemos/Transformer/model.pth'
model.load_state_dict(torch.load(model_path))

my_input = ['we', "should", "protect", "environment"]
x_batch = torch.LongTensor([[en2idx[x] for x in my_input]]).to(device)

cn_sentence = idx_to_sentence(x_batch[0], idx2en, True)
print(cn_sentence)

y_input = torch.ones(batch_size, maxlen,
dtype=torch.long).to(device) * PAD_ID
y_input[0] = en2idx['<S>']
# y_input = y_batch
with torch.no_grad():
for i in range(1, y_input.shape[1]):
y_hat = model(x_batch, y_input)
for j in range(batch_size):
y_input[j, i] = torch.argmax(y_hat[j, i - 1])
output_sentence = idx_to_sentence(y_input[0], idx2cn, True)
print(output_sentence)


if __name__ == '__main__':
main()

一开始,还是先获取词表,并初始化模型。

1
2
3
4
5
6
7
8
9
10
11
12
def main():
device = 'cuda'
cn2idx, idx2cn = load_cn_vocab()
en2idx, idx2en = load_en_vocab()

model = Transformer(len(en2idx), len(cn2idx), 0, d_model, d_ff, n_layers,
heads, dropout_rate, maxlen)
model.to(device)
model.eval()

model_path = 'dldemos/Transformer/model.pth'
model.load_state_dict(torch.load(model_path))

之后,我们用自己定义的句子(要做好分词)代替原来的输入x_batch。如果要测试某个数据集,只要把这里x_batch换成测试集里的数据即可。
我们可以顺便把序号数组用idx_to_sentence转回英文,看看序号转换有没有出错。

1
2
3
4
5
my_input = ['we', "should", "protect", "environment"]
x_batch = torch.LongTensor([[en2idx[x] for x in my_input]]).to(device)

cn_sentence = idx_to_sentence(x_batch[0], idx2en, True)
print(cn_sentence)

这段代码会输出we should protect environment。这说明x_batch是我们想要的序号数组。

最后,我们利用Transformer自回归地生成句子,并输出句子。

1
2
3
4
5
6
7
8
9
10
11
y_input = torch.ones(batch_size, maxlen,
dtype=torch.long).to(device) * PAD_ID
y_input[0] = en2idx['<S>']
# y_input = y_batch
with torch.no_grad():
for i in range(1, y_input.shape[1]):
y_hat = model(x_batch, y_input)
for j in range(batch_size):
y_input[j, i] = torch.argmax(y_hat[j, i - 1])
output_sentence = idx_to_sentence(y_input[0], idx2cn, True)
print(output_sentence)

要自回归地生成句子,我们先给句子填入无效字符<pad>,再把第一个字符换成句子开始字符<S>

1
2
3
y_input = torch.ones(batch_size, maxlen,
dtype=torch.long).to(device) * PAD_ID
y_input[0] = en2idx['<S>']

之后,我们循环调用Transformer,获取下一个单词的概率分布。我们可以认为,概率最大的那个单词就是模型预测的下一个单词。因此,我们可以用argmax获取预测的下一个单词的序号,填回y_input。这里的y_input和训练时那个y_batch是同一个东西。

1
2
3
4
5
6
# y_input = y_batch
with torch.no_grad():
for i in range(1, y_input.shape[1]):
y_hat = model(x_batch, y_input)
for j in range(batch_size):
y_input[j, i] = torch.argmax(y_hat[j, i - 1])

最后只要输出生成的句子即可。

1
2
output_sentence = idx_to_sentence(y_input[0], idx2cn, True)
print(output_sentence)

由于训练数据非常少,而且数据都来自新闻,我只好选择了一个比较常见的句子”we should protect environment”作为输入。模型翻译出了一个比较奇怪的结果。

1
<S> 要 保护 环境 保护 环境 保护 环境 保护 环境 保护 环境 保护 环境 保护 环境 的 生态 环境 落实 好 环境 </S> 环境 </S> 有效 保护 环境 </S>...

可以看出,模型确实学到了东西,能翻译出“要保护环境”。但是,这翻译的结果也太长太奇怪了。感觉是对训练数据过拟合了。当然,还是那句话,训练集里的数据太少。要提升模型性能并缓解过拟合,加数据集是最好的方法。这个结果起码说明我们Tranformer的编写没有问题。

在生成新句子的时候,我直接拿概率最高的单词当做预测的下一个单词。其实,还有一些更加高级的生成算法,比如Beam Search。如果模型训练得比较好,可以用这些高级一点的算法提高生成句子的质量。

我读了网上几份Transformer实现。这些实现在生成句子算BLEU score时,竟然直接输入测试句子的前n-1个单词,把输出的n-1个单词拼起来,作为模型的翻译结果。这个过程等价于告诉你前i个翻译答案,你去输出第i+1个单词,再把每个结果拼起来。这样写肯定是不合理的。正常来说应该是照着我这样自回归地生成翻译句子。大家参考网上的Transformer代码时要多加留心。

总结

只要读懂了 Transfomer 的论文,用 PyTorch 实现一遍 Transformer 是很轻松的。但是,代码实现中有非常多论文不会提及的细节,你自己实现时很容易踩坑。在这篇文章里,我完整地介绍了一个英中翻译 Transformer 的 PyTorch 实现,相信读者能够跟随这篇文章实现自己的 Transformer,并在代码实现的过程中加深对论文的理解。

再稍微总结一下代码实现中的一些值得注意的地方。代码中最大的难点是 mask 的实现。mask 的处理稍有闪失,就可能会让计算结果中遍布 NaN。一定要想清楚各个模块的 mask 是从哪来的,它们在注意力计算里是怎么被用上的。

另外,有两处地方的实现比较灵活。一处是位置编码的实现,一处是多头注意力中怎么描述“多头”。其他模块的实现都大差不差,千篇一律。

最后再提醒一句,要从头训练一个模型,一定要从小数据集上开始做。不然你训练个半天,结果差了,你不知道是数据有问题,还是代码有问题。我之前一直在使用很大的训练集,每次调试都非常麻烦,浪费了很多时间。希望大家引以为戒。

参考资料

感谢 https://github.com/P3n9W31/transformer-pytorch 提供的数据集。

一份简明的Transformer实现代码 https://github.com/hyunwoongko/transformer

一篇不错的Transformer实现教程 https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec

过期内容

我第一次写这篇文章时过于仓促,文章中有不少错误,实验部分也没写完。我后来把本文又重新修改了一遍,补充了实验部分。

我之前使用了一个较大的数据集,但发现做实验做得很慢,于是换了一个较小的数据集。以前的数据集预处理介绍就挪到这里了。

数据集与评测方法

在开启一个深度学习项目之初,要把任务定义好。准确来说,我们要明白这个任务是在完成一个怎样的映射,并准备一个用于评测的数据集,定义好评价指标。

英中翻译,这个任务非常明确,就是把英文的句子翻译成中文。英中翻译的数据集应该包含若干个句子对,每个句子对由一句英文和它对应的中文翻译组成。

中英翻译的数据集不是很好找。有几个比较出名的数据集的链接已经失效了,还有些数据集需要注册与申请后才能获取。我在中文NLP语料库仓库(https://github.com/brightmart/nlp_chinese_corpus)找到了中英文平行语料 translation2019zh。该语料库由520万对中英文语料构成,训练集516万对,验证集3.9万对。用作训练和验证中英翻译模型是足够了。

机器翻译的评测指标叫做BLEU Score。如果模型输出的翻译和参考译文有越多相同的单词、连续2个相同单词、连续3个相同单词……,则得分越高。

PyTorch 提供了便捷的API,我们可以用一行代码算完BLEU Score。

1
2
3
4
5
>>> from torchtext.data.metrics import bleu_score
>>> candidate_corpus = [['My', 'full', 'pytorch', 'test'], ['Another', 'Sentence']]
>>> references_corpus = [[['My', 'full', 'pytorch', 'test'], ['Completely', 'Different']], [['No', 'Match']]]
>>> bleu_score(candidate_corpus, references_corpus)
0.8408964276313782

数据清洗

得到数据集后,下一步要做的是对数据集做处理,把原始数据转化成能够输入神经网络的张量。对于图片,预处理可能是裁剪、缩放,使所有图片都有一样的大小;对于文本,预处理可能是分词、填充。

网盘上下载好 translation2019zh 数据集后,我们来一步一步清洗这个数据集。这个数据集只有两个文件translation2019zh_train.json, translation2019zh_valid.json,它们的结构如下:

text
1
2
3
4
{"english": <english>, "chinese": <chinese>}
{"english": <english>, "chinese": <chinese>}
{"english": <english>, "chinese": <chinese>}
...

这些json文件有点不合标准,每对句子由一行json格式的记录组成。english属性是英文句子,chinese属性是中文句子。比如:

text
1
{"english": "In Italy ...", "chinese": "在意大利 ..."}

因此,在读取数据时,我们可以用下面的代码提取每对句子。

1
2
3
4
5
6
import json

with open(json_path, 'r') as fp:
for line in fp:
line = json.loads(line)
english, chinese = line['english'], line['chinese']

这个数据集有一点不干净,有一些句子对的中英文句子颠倒过来了。为此,我们要稍微处理一下,把这些句子对翻转过来。如果一个英文句子不全由 ASCII 组成,则它可能是一个被标错的中文句子。

1
2
3
# Correct mislabeled data
if not english.isascii():
english, chinese = chinese, english

经过这一步,我们只得到了中英文的字符文本。而在NLP中,大部分处理的最小单位都是符号(token)——对于英文来说,符号是单词、标点;对于中文来说,符号是词语、标点。我们还需要一个符号化的过程。

英文符号化非常方便,torchtext 提供了非常便捷的英文分词 API。

1
2
3
4
from torchtext.data import get_tokenizer

tokenizer = get_tokenizer('basic_english')
english = tokenizer(english)

而中文分词方面,我使用了jieba库。该库可以直接 pip 安装。

1
pip install jieba

分词的 API 是 jieba.cut。由于分词的结果中,相邻的词之间有空格,我一股脑地把所有空白符给过滤掉了。

1
2
3
import jieba
chinese = list(jieba.cut(chinese))
chinese = [x for x in chinese if x not in {' ', '\t'}]

经过这些处理后,每句话被转换成了中文词语或英文单词的数组。整个处理代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def read_file(json_path):
english_sentences = []
chinese_sentences = []
tokenizer = get_tokenizer('basic_english')
with open(json_path, 'r') as fp:
for line in fp:
line = json.loads(line)
english, chinese = line['english'], line['chinese']
# Correct mislabeled data
if not english.isascii():
english, chinese = chinese, english
# Tokenize
english = tokenizer(english)
chinese = list(jieba.cut(chinese))
chinese = [x for x in chinese if x not in {' ', '\t'}]
english_sentences.append(english)
chinese_sentences.append(chinese)
return english_sentences, chinese_sentences

词语转序号

为了让计算机更方便地处理单词,我们还要把单词转换成序号。比如令apple为0号,banana为1号,则句子apple banana apple就转换成了0 1 0

给每一个单词选一个标号,其实就是要建立一个词典。一般来说,我们可以利用他人的统计结果,挑选最常用的一些英文单词和中文词语构成词典。不过,现在我们已经有了一个庞大的中英语料库了,我们可以直接从这个语料库中挑选出最常见的词构成词典。

根据上一步处理得到的句子数组sentences,我们可以用下面的 Python 代码统计出最常见的一些词语,把它们和4个特殊字符<sos>, <eos>, <unk>, <pad>(句子开始字符、句子结束字符、频率太少没有被加入词典的词语、填充字符)一起构成词典。统计字符出现次数是通过 Python 的 Counter 类实现的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from collections import Counter

def create_vocab(sentences, max_element=None):
"""Note that max_element includes special characters"""

default_list = ['<sos>', '<eos>', '<unk>', '<pad>']

char_set = Counter()
for sentence in sentences:
c_set = Counter(sentence)
char_set.update(c_set)

if max_element is None:
return default_list + list(char_set.keys())
else:
max_element -= 4
words_freq = char_set.most_common(max_element)
# pair array to double array
words, freq = zip(*words_freq)
return default_list + list(words)

准备好了词典后,我还编写了两个工具函数sentence_to_tensortensor_to_sentence,它们可以用于字符串数组与序号数组的互相转换。测试这些代码的脚本及其输出如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Dataset.py

def main():
en_sens, zh_sens = read_file(
'data/translation2019zh/translation2019zh_valid.json')
print(*en_sens[0:3])
print(*zh_sens[0:3])
en_vocab = create_vocab(en_sens, 10000)
zh_vocab = create_vocab(zh_sens, 30000)
print(list(en_vocab)[0:10])
print(list(zh_vocab)[0:10])

en_tensors = sentence_to_tensor(en_sens, en_vocab)
zh_tensors = sentence_to_tensor(zh_sens, zh_vocab)

print(tensor_to_sentence(en_tensors[0], en_vocab, True))
print(tensor_to_sentence(zh_tensors[0], zh_vocab))
text
1
2
3
4
5
6
['slowly', 'and', 'not', 'without', 'struggle', ',', 'america', 'began', 'to', 'listen', '.'] ...]
['美国', '缓慢', '地', '开始', '倾听', ',', '但', '并非', '没有', '艰难曲折', '。'] ...]
['<sos>', '<eos>', '<unk>', '<pad>', 'the', '.', ',', 'of', 'and', 'to']
['<sos>', '<eos>', '<unk>', '<pad>', '的', ',', '。', '在', '了', '和']
slowly and not without struggle , america began to listen .
美国缓慢地开始倾听,但并非没有<unk>。

在这一步中,有一个重要的参数:词典的大小。显然,词典越大,能处理的词语越多,但训练速度也会越慢。由于这个项目只是一个用于学习的demo,我设置了比较小的词典大小。想提升整个模型的性能的话,调大词典大小是一个最快的方法。

生成 Dataloader

都说程序员是新时代的农民工,这非常有道理。因为,作为程序员,你免不了要写一些繁重、无聊的数据处理脚本。还好,写完这些无聊的预处理代码后,总算可以使用 PyTorch 的 API 写一些有趣的代码了。

把词语数组转换成序号句子数组后,我们要考虑怎么把序号句子数组输入给模型了。文本数据通常长短不一,为了一次性处理一个 batch 的数据,要把短的句子填充,使得一批句子长度相等。写 Dataloader 时最主要的工作就是填充并对齐句子。

先看一下Dataset的写法。上一步得到的序号句子数组可以塞进Dataset里。注意,每个句子的前后要加上表示句子开始和结束的特殊符号。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
SOS_ID = 0
EOS_ID = 1
UNK_ID = 2
PAD_ID = 3

class TranslationDataset(Dataset):

def __init__(self, en_tensor: np.ndarray, zh_tensor: np.ndarray):
super().__init__()
assert len(en_tensor) == len(zh_tensor)
self.length = len(en_tensor)
self.en_tensor = en_tensor
self.zh_tensor = zh_tensor

def __len__(self):
return self.length

def __getitem__(self, index):
x = np.concatenate(([SOS_ID], self.en_tensor[index], [EOS_ID]))
x = torch.from_numpy(x)
y = np.concatenate(([SOS_ID], self.zh_tensor[index], [EOS_ID]))
y = torch.from_numpy(y)
return x, y

接下来看一下 DataLoader 的写法。在创建 Dataloader 时,最重要的是 collate_fn 的编写,这个函数决定了怎么把多条数据合成一个等长的 batch。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def get_dataloader(en_tensor: np.ndarray,
zh_tensor: np.ndarray,
batch_size=16):

def collate_fn(batch):
...

dataset = TranslationDataset(en_tensor, zh_tensor)
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collate_fn)

return dataloader

collate_fn 的输入是多个 dataset __getitem__ 的返回结果构成的数组。对于我们的 dataset 来说,collate_fn 的输入是 [(x1, y1), (x2, y2), ...] 。我们可以用 zip(*batch) 把二元组数组拆成两个数组 x, y

collate_fn 的输出就是将来 dataloader 的输出。PyTorch 提供了 pad_sequence 函数用来把一批数据填充至等长。

1
2
3
4
5
6
7
8
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
x, y = zip(*batch)
x_pad = pad_sequence(x, batch_first=True, padding_value=PAD_ID)
y_pad = pad_sequence(y, batch_first=True, padding_value=PAD_ID)

return x_pad, y_pad

实现完collate_fn后,我们就可以得到了DataLoader。这样,数据集预处理部分大功告成。

近两年,有许多图像生成类任务的前沿工作都使用了一种叫做”codebook”的机制。追溯起来,codebook机制最早是在VQ-VAE论文中提出的。相比于普通的VAE,VQ-VAE能利用codebook机制把图像编码成离散向量,为图像生成类任务提供了一种新的思路。VQ-VAE的这种建模方法启发了无数的后续工作,包括声名远扬的Stable Diffusion。

在这篇文章中,我将先以易懂的逻辑带领大家一步一步领悟VQ-VAE的核心思想,再介绍VQ-VAE中关键算法的具体形式,最后把VQ-VAE的贡献及其对其他工作的影响做一个总结。通过阅读这篇文章,你不仅能理解VQ-VAE本身的原理,更能知道如何将VQ-VAE中的核心机制活学活用。

从 AE 到 VQ-VAE

为什么VQ-VAE想要把图像编码成离散向量?让我们从最早的自编码器(Autoencoder, AE)开始一步一步谈起。AE是一类能够把图片压缩成较短的向量的神经网络模型,其结构如下图所示。AE包含一个编码器$e()$和一个解码器$d()$。在训练时,输入图像$\mathbf{x}$会被编码成一个较短的向量$\mathbf{z}$,再被解码回另一幅长得差不多的图像$\hat{\mathbf{x}}$。网络的学习目标是让重建出来的图像$\hat{\mathbf{x}}$和原图像$\mathbf{x}$尽可能相似。

解码器可以把一个向量解码成图片。换一个角度看,解码器就是一个图像生成模型,因为它可以根据向量来生成图片。那么,AE可不可以用来做图像生成呢?很可惜,AE的编码器编码出来的向量空间是不规整的。也就是说,解码器只认识经编码器编出来的向量,而不认识其他的向量。如果你把自己随机生成出来的向量输入给解码器,解码器是生成不出有意义的图片的。AE不能够随机生成图片,所以它不能很好地完成图像生成任务,只能起到把图像压缩的作用。

AE离图像生成只差一步了。只要AE的编码空间比较规整,符合某个简单的数学分布(比如最常见的标准正态分布),那我们就可以从这个分布里随机采样向量,再让解码器根据这个向量来完成随机图片生成了。VAE就是这样一种改进版的AE。它用一些巧妙的方法约束了编码向量$\mathbf{z}$,使得$\mathbf{z}$满足标准正态分布。这样,解码器不仅认识编码器编出的向量,还认识其他来自标准正态分布的向量。训练完成后,我们就可以扔掉编码器,用来自标准正态分布的随机向量和解码器来实现随机图像生成了。

VAE的实现细节就不在这里赘述了,是否理解它对理解VQ-VAE没有影响。我们只需知道VAE可以把图片编码成符合标准正态分布的向量即可。让向量符合标准正态分布的原因是方便随机采样。同时,需要强调的是,VAE编码出来的向量是连续向量,也就是向量的每一维都是浮点数。如果把向量的某一维稍微改动0.0001,解码器还是认得这个向量,并且会生成一张和原向量对应图片差不多的图片。

但是,VAE生成出来的图片都不是很好看。VQ-VAE的作者认为,VAE的生成图片之所以质量不高,是因为图片被编码成了连续向量。而实际上,把图片编码成离散向量会更加自然。比如我们想让画家画一个人,我们会说这个是男是女,年龄是偏老还是偏年轻,体型是胖还是壮,而不会说这个人性别是0.5,年龄是0.6,体型是0.7。因此,VQ-VAE会把图片编码成离散向量,如下图所示。

把图像编码成离散向量后,又会带来两个新的问题。第一个问题是,神经网络会默认输入满足一个连续的分布,而不善于处理离散的输入。如果你直接输入0, 1, 2这些数字,神经网络会默认1是一个处于0, 2中间的一种状态。为了解决这一问题,我们可以借鉴NLP中对于离散单词的处理方法。为了处理离散的输入单词,NLP模型的第一层一般都是词嵌入层,它可以把每个输入单词都映射到一个独一无二的连续向量上。这样,每个离散的数字都变成了一个特别的连续向量了。

我们可以把类似的嵌入层加到VQ-VAE的解码器前。这个嵌入层在VQ-VAE里叫做”embedding space(嵌入空间)”,在后续文章中则被称作”codebook”。

离散向量的另一个问题是它不好采样。回忆一下,VAE之所以把图片编码成符合正态分布的连续向量,就是为了能在图像生成时把编码器扔掉,让随机采样出的向量也能通过解码器变成图片。现在倒好,VQ-VAE把图片编码了一个离散向量,这个离散向量构成的空间是不好采样的。VQ-VAE不是面临着和AE一样的问题嘛。

这个问题是无解的。没错!VQ-VAE根本不是一个图像生成模型。它和AE一样,只能很好地完成图像压缩,把图像变成一个短得多的向量,而不支持随机图像生成。VQ-VAE和AE的唯一区别,就是VQ-VAE会编码出离散向量,而AE会编码出连续向量。

可为什么VQ-VAE会被归类到图像生成模型中呢?这是因为VQ-VAE的作者利用VQ-VAE能编码离散向量的特性,使用了一种特别的方法对VQ-VAE的离散编码空间采样。VQ-VAE的作者之前设计了一种图像生成网络,叫做PixelCNN。PixelCNN能拟合一个离散的分布。比如对于图像,PixelCNN能输出某个像素的某个颜色通道取0~255中某个值的概率分布。这不刚好嘛,VQ-VAE也是把图像编码成离散向量。换个更好理解的说法,VQ-VAE能把图像映射成一个「小图像」。我们可以把PixelCNN生成图像的方法搬过来,让PixelCNN学习生成「小图像」。这样,我们就可以用PixelCNN生成离散编码,再利用VQ-VAE的解码器把离散编码变成图像。

让我们来整理一下VQ-VAE的工作过程。

  1. 训练VQ-VAE的编码器和解码器,使得VQ-VAE能把图像变成「小图像」,也能把「小图像」变回图像。
  2. 训练PixelCNN,让它学习怎么生成「小图像」。
  3. 随机采样时,先用PixelCNN采样出「小图像」,再用VQ-VAE把「小图像」翻译成最终的生成图像。

到这里,我们已经学完了VQ-VAE的核心思想。让我们来总结一下。VQ-VAE不是一个VAE,而是一个AE。它的目的是把图像压缩成离散向量。或者换个角度说,它提供了把大图像翻译成「小图像」的方法,也提供了把「小图像」翻译成大图像的方法。这样,一个随机生成大图像的问题,就被转换成了一个等价的随机生成一个较小的「图像」的问题。有一些图像生成模型,比如PixelCNN,更适合拟合离散分布。可以用它们来完成生成「小图像」的问题,填补上VQ-VAE生成图片的最后一片空缺。

VQ-VAE 设计细节

在上一节中,我们虽然认识了VQ-VAE的核心思想,但略过了不少实现细节,比如:

  • VQ-VAE的编码器怎么输出离散向量。
  • VQ-VAE怎么优化编码器和解码器。
  • VQ-VAE怎么优化嵌入空间。

在这一节里,我们来详细探究这些细节。

输出离散编码

想让神经网络输出一个整数,最简单的方法是和多分类模型一样,输出一个Softmax过的概率分布。之后,从概率分布里随机采样一个类别,这个类别的序号就是我们想要的整数。比如在下图中,我们想得到一个由3个整数构成的离散编码,就应该让编码器输出3组logit,再经过Softmax与采样,得到3个整数。

但是,这么做不是最高效的。得到离散编码后,下一步我们又要根据嵌入空间把离散编码转回一个向量。可见,获取离散编码这一步有一点多余。能不能把编码器的输出张量(它之前的名字叫logit)、解码器的输入张量embedding、嵌入空间直接关联起来呢?

VQ-VAE使用了如下方式关联编码器的输出与解码器的输入:假设嵌入空间已经训练完毕,对于编码器的每个输出向量$z_e(x)$,找出它在嵌入空间里的最近邻$z_q(x)$,把$z_e(x)$替换成$z_q(x)$作为解码器的输入。

求最近邻,即先计算向量与嵌入空间$K$个向量每个向量的距离,再对距离数组取一个argmin,求出最近的下标(比如图中的0, 1, 1),最后用下标去嵌入空间里取向量。下标构成的数组(比如图中的[0, 1, 1])也正是VQ-VAE的离散编码。

就这样,我们知道了VQ-VAE是怎么生成离散编码的。VQ-VAE的编码器其实不会显式地输出离散编码,而是输出了多个「假嵌入」$z_e(x)$。之后,VQ-VAE对每个$z_e(x)$在嵌入空间里找最近邻,得到真正的嵌入$z_q(x)$,把$z_q(x)$作为解码器的输入。

虽然我们现在能把编码器和解码器拼接到一起,但现在又多出了一个问题:怎么让梯度从解码器的输入$z_q(x)$传到$z_e(x)$?从$z_e(x)$到$z_q(x)$的变换是一个从数组里取值的操作,这个操作是求不了导的。我们在下一小节里来详细探究一下怎么优化VQ-VAE的编码器和解码器。

优化编码器和解码器

为了优化编码器和解码器,我们先来制订一下VQ-VAE的整体优化目标。由于VQ-VAE其实是一个AE,误差函数里应该只有原图像和目标图像的重建误差。

或者非要从VAE的角度说也行。VQ-VAE相当于输出了一个one-hot离散分布。假设输入图像$x$的离散编码$z$是$k$,则分布中仅有$q(z=k|x)=1$,$q(z=others|x)=0$。令离散编码$z$的先验分布是均匀分布(假设不知道输入图像$x$,每个离散编码取到的概率是等同的),则先验分布$q(z)$和后验分布$q(z|x)$的KL散度是常量。因此,KL散度项不用算入损失函数里。理解此处的数学推导意义不大,还不如直接理解成VQ-VAE其实是一个AE。

但直接拿这个误差来训练是不行的。误差中,$z_q(x)$是解码器的输入。从编码器输出$z_e(x)$到$z_q(x)$这一步是不可导的,误差无法从解码器传递到编码器上。要是可以把$z_q(x)$的梯度直接原封不动地复制到$z_e(x)$上就好了。

VQ-VAE使用了一种叫做”straight-through estimator”的技术来完成梯度复制。这种技术是说,前向传播和反向传播的计算可以不对应。你可以为一个运算随意设计求梯度的方法。基于这一技术,VQ-VAE使用了一种叫做$sg$(stop gradient,停止梯度)的运算:

也就是说,前向传播时,$sg$里的值不变;反向传播时,$sg$按值为0求导,即此次计算无梯度。(反向传播其实不会用到式子的值,只会用到式子的梯度。反向传播用到的loss值是在前向传播中算的)。

基于这种运算,我们可以设计一个把梯度从$z_e(x)$复制到$z_q(x)$的误差:

也就是说,前向传播时,就是拿解码器输入$z_q(x)$来算梯度。

而反向传播时,按下面这个公式求梯度,等价于把解码器的梯度全部传给$z_e(x)$。

这部分的PyTorch实现如下所示。在PyTorch里,(x).detach()就是$sg(x)$,它的值在前向传播时取x,反向传播时取0

1
L = x - decoder(z_e + (z_q - z_e).detach())

通过这一技巧,我们完成了梯度的传递,可以正常地训练编码器和解码器了。

优化嵌入空间

到目前为止,我们的讨论都是建立在嵌入空间已经训练完毕的前提上的。现在,我们来讨论一下嵌入空间的训练方法。

嵌入空间的优化目标是什么呢?嵌入空间的每一个向量应该能概括一类编码器输出的向量,比如一个表示「青年」的向量应该能概括所有14-35岁的人的照片的编码器输出。因此,嵌入空间的向量应该和其对应编码器输出尽可能接近。如下面的公式所示,$z_e(x)$是编码器的输出向量,$z_q(x)$是其在嵌入空间的最近邻向量。

但作者认为,编码器和嵌入向量的学习速度应该不一样快。于是,他们再次使用了停止梯度的技巧,把上面那个误差函数拆成了两部分。其中,$\beta$控制了编码器的相对学习速度。作者发现,算法对$\beta$的变化不敏感,$\beta$取0.1~2.0都差不多。

其实,在论文中,作者分别讨论了上面公式里的两个误差。第一个误差来自字典学习算法里的经典算法Vector Quantisation(VQ),也就是VQ-VAE里的那个VQ,它用于优化嵌入空间。第二个误差叫做专注误差,它用于约束编码器的输出,不让它跑到离嵌入空间里的向量太远的地方。

这样,VQ-VAE总体的损失函数可以写成:(由于算上了重建误差,我们多加一个$\alpha$用于控制不同误差之间的比例)

总结

VQ-VAE是一个把图像编码成离散向量的图像压缩模型。为了让神经网络理解离散编码,VQ-VAE借鉴了NLP的思想,让每个离散编码值对应一个嵌入,所有的嵌入都存储在一个嵌入空间(又称”codebook”)里。这样,VQ-VAE编码器的输出是若干个「假嵌入」,「假嵌入」会被替换成嵌入空间里最近的真嵌入,输入进解码器里。

VQ-VAE的优化目标由两部分组成:重建误差和嵌入空间误差。重建误差为输入图片和重建图片的均方误差。为了让梯度从解码器传到编码器,作者使用了一种巧妙的停止梯度算子,让正向传播和反向传播按照不同的方式计算。嵌入空间误差为嵌入和其对应的编码器输出的均方误差。为了让嵌入和编码器以不同的速度优化,作者再次使用了停止梯度算子,把嵌入的更新和编码器的更新分开计算。

训练完成后,为了实现随机图像生成,需要对VQ-VAE的离散分布采样,再把采样出来的离散向量对应的嵌入输入进解码器。VQ-VAE论文使用了PixelCNN来采样离散分布。实际上,PixelCNN不是唯一一种可用的拟合离散分布的模型。我们可以把它换成Transformer,甚至是diffusion模型。如果你当年看完VQ-VAE后立刻把PixelCNN换成了diffusion模型,那么恭喜你,你差不多提前设计出了Stable Diffusion。

可见,VQ-VAE最大的贡献是提供了一种图像压缩思路,把生成大图像的问题转换成了一个更简单的生成「小图像」的问题。图像压缩成离散向量时主要借助了嵌入空间,或者说”codebook”这一工具。这种解决问题的思路可以应用到所有图像生成类任务上,比如超分辨率、图像修复、图像去模糊等。所以近两年我们能看到很多使用了codebook的图像生成类工作。

参考资料

PixelCNN的介绍可以参见我之前的文章:详解PixelCNN大家族。

VQ-VAE的论文为Neural Discrete Representation Learning。这篇文章不是很好读懂,建议直接读我的这篇解读。再推荐另一份还不错的中文解读 https://www.spaces.ac.cn/archives/6760。

图像生成是一个较难建模的任务。为此,我们要用GAN、VAE、Diffusion等精巧的架构来建模图像生成。可是,在NLP中,文本生成却有一种非常简单的实现方法。NLP中有一种基础的概率模型——N元语言模型。N元语言模型可以根据句子的前几个字预测出下一个字的出现概率。比如看到「我爱吃苹……」这句话的前几个字,我们不难猜出下一个字大概率是「果」字。利用N元语言模型,我们可以轻松地实现一个文本生成算法:输入空句子,采样出第一个字;输入第一个字,采样出第二个字;输入前两个字,输出第三个字……以此类推。

既然如此,我们可不可以把相同的方法搬到图像生成里呢?当然可以。虽然图像是二维的数据,不像一维的文本一样有先后顺序,但是我们可以强行给图像的每个像素规定一个顺序。比如,我们可以从左到右,从上到下地给图像标上序号。这样,从逻辑上看,图像也是一个一维数据,可以用NLP中的方法来按照序号实现图像生成了。

PixelCNN就是一个使用这种方法生成图像的模型。可为什么PixelCNN的名气没有GAN、VAE那么大?为什么PixelCNN可以用CNN而不是RNN来处理一维化图像?为什么PixelCNN是一种「自回归模型」?别急,在这篇文章中,我们将认识PixelCNN及其改进模型Gated PixelCNN和PixelCNN++,并认真学习它们的实现代码。看完文章后,这些问题都会迎刃而解。

PixelCNN

如前所述,PixelCNN借用了NLP里的方法来生成图像。模型会根据前i - 1个像素输出第i个像素的概率分布。训练时,和多分类任务一样,要根据第i个像素的真值和预测的概率分布求交叉熵损失函数;采样时,直接从预测的概率分布里采样出第i个像素。根据这些线索,我们来尝试自己「发明」一遍PixelCNN。

这种模型最朴素的实现方法,是输入一幅图像的前i - 1个像素,输出第i个像素的概率分布,即第i个像素取某种颜色的概率的数组。为了方便讨论,我们先只考虑单通道图像,每个像素的颜色取值只有256种。因此,准确来说,模型的输出是256个经过softmax的概率。这样,我们得到了一个V1.0版本的模型。

等等,模型不是叫「PixelCNN」吗?CNN跑哪去了?的确,对于图像数据,最好还是使用CNN,快捷又有效。因此,我们应该修改模型,令模型的输入为整幅图像和序号i。我们根据序号i,过滤掉ii之后的像素,用CNN处理图像。输出部分还是保持一致。

V2.0并不是最终版本,我们可以暂时不用考虑实现细节,比如这里的「过滤」是怎么实现的。硬要做的话,这种过滤也可以暴力实现:把无效像素初始化为0,每次卷积后再把无效像素置0。

改进之后,V2.0版本的模型确实能快速计算第i个像素的概率分布了。可是,CNN是很擅长同时生成一个和原图像长宽相同的张量的,只算一个像素的概率分布还称不上高效。所以,我们可以让模型输入一幅图像,同时输出图像每一处的概率分布。

这次的改进并不能加速采样。但是,在训练时,由于整幅训练图像已知,我们可以在一次前向传播后得到图像每一处的概率分布。假设图像有N个像素,我们就等于是在并行地训练N个样本,训练速度快了N倍!

这种并行训练的想法和Transformer如出一辙。

V3.0版本的PixelCNN已经和论文里的PixelCNN非常接近了,我们来探讨一下网络的实现细节。相比普通的CNN,PixelCNN有一个特别的约束:第i个像素只能看到前i-1个像素的信息,不能看到第i个像素及后续像素的信息。对于V2.0版本只要输出一个概率分布的PixelCNN,我们可以通过一些简单处理过滤掉第i个像素之后的信息。而对于并行输出所有概率分布的V3.0版本,让每个像素都忽略后续像素的信息的方法就不是那么显然了。

PixelCNN论文里提出了一种掩码卷积机制,这种机制可以巧妙地掩盖住每个像素右侧和下侧的信息。具体来说,PixelCNN使用了两类掩码卷积,我们把两类掩码卷积分别称为「A类」和「B类」。二者都是对卷积操作的卷积核做了掩码处理,使得卷积核的右下部分不产生贡献。A类和B类的唯一区别在于卷积核的中心像素是否产生贡献。CNN的第一个的卷积层使用A类掩码卷积,之后每一层的都使用B类掩码卷积。如下图所示。

为什么要先用一次A类掩码卷积,再每次使用B类掩码卷积呢?我们不妨来做一个实验。对于一个7x7的图像,我们先用1次3x3 A类掩码卷积,再用若干次3x3 B类掩码卷积。我们观察图像中心处的像素在每次卷积后的感受野(即输入图像中哪些像素的信息能够传递到中心像素上)。

不难看出,经过了第一个A类掩码卷积后,每个像素就已经看不到自己位置上的输入信息了。再经过两次B类卷积,中心像素能够看到左上角大部分像素的信息。这满足PixelCNN的约束。

而如果一直使用A类卷积,每次卷积后中心像素都会看漏一些信息(不妨对比下面这张示意图和上面那张示意图)。多卷几层后,中心像素的值就会和输入图像毫无关系。

只是用B类卷积也是不行的。显然,如果第一层就使用B类卷积,中心像素还是能看到自己位置的输入信息。这打破了PixelCNN的约束。这下,我们能明白为什么只能先用一次A类卷积,再用若干次B类卷积了。

利用两类掩码卷积,PixelCNN满足了每个像素只能接受之前像素的信息这一约束。除此之外,PixelCNN就没有什么特别的地方了。我们可以用任意一种CNN架构来实现PixelCNN。PixelCNN论文使用了一种类似于ResNet的架构。其中,第一个7x7卷积层用了A类掩码卷积,之后所有3x3卷积都是B类掩码卷积。

到目前为止,我们已经成功搭建了处理单通道图像的PixelCNN。现在,我们来尝试把它推广到多通道图像上。相比于单通道图像,多通道图像只不过是一个像素由多个颜色分量组成。我们可以把一个像素的颜色分量看成是子像素。在定义约束关系时,我们规定一个子像素只由它之前的子像素决定。比如对于RGB图像,R子像素由它之前所有像素决定,G子像素由它的R子像素和之前所有像素决定,B子像素由它的R、G子像素和它之前所有像素决定。生成图像时,我们一个子像素一个子像素地生成。

把我们的PixelCNN V3.0推广到RGB图像时,我们要做的第一件事就是修改网络的通道数量。由于现在要预测三个颜色通道,网络的输出应该是一个[256x3, H, W]形状的张量,即每个像素输出三个概率分布,分别表示R、G、B取某种颜色的概率。同时,本质上来讲,网络是在并行地为每个像素计算3组结果。因此,为了达到同样的性能,网络所有的特征图的通道数也要乘3。

这里说网络中间的通道数要乘3只是一种方便理解的说法。实际上,中间的通道数可以随意设置,是不是3的倍数都无所谓,只是所有通道在逻辑上被分成了3组。我们稍后会利用到「中间结果的通道数应该能被拆成3组」这一性质。

图像变为多通道后,A类卷积和B类卷积的定义也需要做出一些调整。我们不仅要考虑像素在空间上的约束,还要考虑一个像素内子像素间的约束。为此,我们要用不同的策略实现约束。为了方便描述,我们设卷积核组的形状为[o, i, h, w],其中o为输出通道数,i为输入通道数,h, w为卷积核的高和宽。

  1. 对于通道间的约束,我们要在o, i两个维度上设置掩码。设输出通道可以被拆成三组o1, o2, o3,输入通道可以被拆成三组i1, i2, i3,即o1 = 0:o/3, o2 = o/3:o*2/3, o3 = o*2/3:oi1 = 0:i/3, i2 = i/3:i*2/3, i3 = i*2/3:i。序号1, 2, 3分别表示这组通道是在维护R, G, B的计算。我们对输入通道组和输出通道组之间进行约束。对于A类卷积,我们令o1看不到i1, i2, i3o2看不到i2, i3o3看不到i3;对于B类卷积,我们取消每个通道看不到自己的限制,即在A类卷积的基础上令o1看到i1o2看到i2o3看到i3

  2. 对于空间上的约束,我们还是和之前一样,在h, w两个维度上设置掩码。由于「是否看到自己」的处理已经在o, i两个维度里做好了,我们直接在空间上用原来的B类卷积就行。

就这样,修改了通道数,修改了卷积核的掩码后,我们成功实现了论文里的PixelCNN。让我们把这个过程总结一下。PixelCNN的核心思想是给图像的子像素定义一个先后顺序,之后让每个子像素的颜色取值分布由之前所有的子像素决定。实现PixelCNN时,可以用任意一种CNN架构,并注意两点:

  1. 网络的输出是一个经softmax的概率分布。
  2. 网络的所有卷积层要替换成带掩码的卷积层,第一个卷积层用A类掩码,后面的用B类掩码。

学完了PixelCNN,我们在闲暇之余来谈一谈PixelCNN和其他生成网络的对比情况。精通数学的人,会把图像生成问题看成学习一个图像的分布。每次生成一张图片,就是在图像分布里随机采样一个张量。学习一个分布,最便捷的方法是定义一个带参数$\theta$的概率模型$P_\theta$,最大化来自数据集的图像$\mathbf{x}$的概率$P_\theta(\mathbf{x})$。

可问题来了:一个又方便采样,又能计算概率的模型不好设计。VAE和Diffusion建模了把一个来自正态分布的向量$\mathbf{z}$变化成$\mathbf{x}$的过程,并使用了统计学里的变分推理,求出了$P_\theta(\mathbf{x})$的一个下界,再设法优化这个下界。GAN干脆放弃了概率模型,直接拿一个神经网络来评价生成的图像好不好。

PixelCNN则正面挑战了建立概率模型这一任务。它把$P_\theta(\mathbf{x})$定义为每个子像素出现概率的乘积,而每个子像素的概率仅由它之前的子像素决定。

由于我们可以轻松地用神经网络建模每个子像素的概率分布并完成采样,PixelCNN的采样也是很方便的。我们可以说PixelCNN是一个既方便采样,又能快速地求出图像概率的模型。

相比与其他生成模型,PixelCNN直接对$P_\theta(\mathbf{x})$建模,在和概率相关的指标上表现优秀。很可惜,能最大化数据集的图像的出现概率,并不代表图像的生成质量就很优秀。因此,一直以来,以PixelCNN为代表的对概率直接建模的生成模型没有受到过多的关注。可能只有少数必须要计算图像概率分布的任务才会用到PixelCNN。

除了能直接计算图像的概率外,PixelCNN还有一大特点:PixelCNN能输出离散的颜色值。VAE和GAN这些模型都是把图像的颜色看成一个连续的浮点数,模型的输入和输出的取值范围都位于-1到1之间(有些模型是0到1之间)。而PixelCNN则输出的是像素取某个颜色的概率分布,它能描述的颜色是有限而确定的。假如我们是在生成8位单通道图像,那网络就只输出256个离散的概率分布。能生成离散输出这一特性启发了后续很多生成模型。另外,这一特性也允许我们指定颜色的亮度级别。比如对于黑白手写数字数据集MNIST,我们完全可以用黑、白两种颜色来描述图像,而不是非得用256个灰度级来描述图像。减少亮度级别后,网络的训练速度能快上很多。

在后续的文献中,PixelCNN被归类为了自回归生成模型。这是因为PixelCNN在生成图像时,要先输入空图像,得到第一个像素;把第一个像素填入空图像,输入进模型,得到第二个像素……。也就是说,一个图像被不断扔进模型,不断把上一时刻的输出做为输入。这种用自己之前时刻的状态预测下一个状态的模型,在统计学里被称为自回归模型。如果你在其他图像生成文献中见到了「自回归模型」这个词,它大概率指的就是PixelCNN这种每次生成一个像素,该像素由之前所有像素决定的生成模型。

Gated PixelCNN

首篇提出PixelCNN的论文叫做Pixel Recurrent Neural Networks。没错!这篇文章的作者提出了一种叫做PixelRNN的架构,PixelCNN只是PixelRNN的一个变种。可能作者一开始也没指望PixelCNN有多强。后来,人们发现PixelCNN的想法还挺有趣的,但是原始的PixelCNN设计得太烂了,于是开始着手改进原始的PixelCNN。

PixelCNN的掩码卷积其实有一个重大漏洞:像素存在视野盲区。如下图所示,在我们刚刚的实验中,中心像素看不到右上角三个本应该能看到的像素。哪怕你对用B类卷积多卷几次,右上角的视野盲区都不会消失。

为此,PixelCNN论文的作者们又打了一些补丁,发表了Conditional Image Generation with PixelCNN Decoders这篇论文。这篇论文提出了一种叫做Gated PixelCNN的改进架构。Gated PixelCNN使用了一种更好的掩码卷积机制,消除了原PixelCNN里的视野盲区。如下图所示,Gated PixelCNN使用了两种卷积——垂直卷积和水平卷积——来分别维护一个像素上侧的信息和左侧的信息。垂直卷积的结果只是一些临时量,而水平卷积的结果最终会被网络输出。可以看出,使用这种新的掩码卷积机制后,每个像素能正确地收到之前所有像素的信息了。

除此之外,Gated PixelCNN还把网络中的激活函数从ReLU换成了LSTM的门结构。Gated PixelCNN用下图的模块代替了原PixelCNN的普通残差模块。
模块的输入输出都是两个量,左边的量是垂直卷积中间结果,右边的量是最后用来计算输出的量。垂直卷积的结果会经过偏移和一个1x1卷积,再加到水平卷积的结果上。两条计算路线在输出前都会经过门激活单元。所谓门激活单元,就是输入两个形状相同的量,一个做tanh,一个做sigmoid,两个结果相乘再输出。此外,模块右侧那部分还有一个残差连接。

除了上面的两项改动,Gated PixelCNN还做出了其他的一些改动。比如,Gated PixelCNN支持带约束的图像生成,比如根据文字生成图片、根据类别生成图片。用于约束生成的向量$\mathbf{h}$会被输入进网络每一层的激活函数中。当然,这些改动不是为了提升原PixelCNN的性能。

PixelCNN++

之后,VAE的作者也下场了,提出了一种改进版的PixelCNN,叫做PixelCNN++。这篇论文没有多余的废话,在摘要中就简明地指出了PixelCNN++的几项改动:

  1. 使用logistic分布代替256路softmax
  2. 简化RGB子像素之间的约束关系
  3. 使用U-Net架构
  4. 使用dropout正则化

这几项改动中,第一项改动是最具启发性的,这一技巧可以拓展到其他任务上。让我们主要学习一下第一项改动,并稍微浏览一下其他几项改动。

离散logistic混合似然

原PixelCNN使用256路softmax建模一个像素的颜色概率分布。这么做确实能让模型更加灵活,但有若干缺点。首先,计算这么多的概率值会占很多内存;其次,由于每次训练只有一个位置的标签为1,其他255个位置的标签都是0,模型可学习参数的梯度会很稀疏;最后,在这种概率分布方式下,256种颜色是分开考虑的,这导致模型并不知道相邻的颜色比较相似(比如颜色值128和127、129比较相似)这一事实。总之,用softmax独立地表示各种颜色有着诸多的不足。

作者把颜色的概率分布建模成了连续分布,一下子克服掉了上述所有难题。让我们来仔细看一下新概率分布的定义方法。

首先,新概率分布使用到的连续分布叫做logistic分布。它有两个参数:均值$\mu$和方差$s^2$。它的概率密度函数为:

logistic分布的概率密度函数看起来比较复杂。但是,如果对这个函数积分,得到的累计分布函数就是logistic函数。如果令均值为0,方差为1,则logistic函数就是我们熟悉的sigmoid函数了。

接着,每个分布可能是$K$个参数不同的logistic分布中的某一个,选择某个logistic分布的概率由$\pi_i$表示。比如$K=2$,$\pi_1 = 0.3, \pi_2=0.7$,就说明有两个可选的logisti分布,每个分布有30%的概率会使用1号logistic分布,有70%的概率会使用2号logistic分布。 这里的$\pi_i$和原来256路softmax的输出的意义一样,都是选择某个东西的概率。当然,$K$会比256要小很多,不然这种改进就起不到减小计算量的作用了。设一个输出颜色为$v$,它的数学表达式为:

可logsitc分布是一个连续分布,而我们想得到256个颜色中某个颜色的概率,即得到一个离散的分布。因此,在最后一步,我们要从上面这个连续分布里得到一个离散的分布。我们先不管$K$和$\pi_i$,只考虑有一个logistic分布的情况。根据统计学知识可知,要从连续分布里得到一个离散分布,可以把定义域拆成若干个区间,对每个区间的概率求积分。在我们的例子里,我们可以把实数集拆成256个区间,令$(-\infty, 0.5]$为第1个区间,$(0.5, 1.5]$为第2个区间,……,$(253.5, 254.5]$为第255个区间, $(254.5, +\infty)$为第256个区间。

对概率密度函数求积分,就是在累积分布函数上做差。因此,对于某个离散颜色值$x\in[0, 255], x\in \mathbb{N}$,已知一个logistic分布$logistic(\mu, s)$,则这个颜色值的出现概率是:

其中,$\sigma()$是sigmoid函数。$\sigma((x-\mu)/s)$就是分布的累积分布函数。

可以看出,使用这种区间划分方法,位于0处和位于255处的颜色的概率相对会高一点。这一特点符合作者统计出的CIFAR-10里的颜色分布规律。

当有$K$个logistic分布时,只要把各个分布的概率做一个加权和就行(公式省略掉了$x$位于边界处的情况)。

至此,我们已经知道了怎么用一个「离散logistic混合似然」来建模颜色的概率分布了。这个更高级的颜色分布以logistic分布为基础,以比例(概率)$\pi_i$混合了$K$个logstic分布,并用巧妙的方法把连续分布转换成了离散分布。

简化RGB子像素之间的约束关系

在原PixelCNN中,生成一个像素的RGB三个子像素时,为了保证子像素之间的约束,我们要把模型中所有特征图的通道分成三组,并用掩码来维持三组通道间的约束。这样做太麻烦了。因此,PixelCNN++对约束做了一定的简化:根据之前所有像素,网络一次性输出三个子像素的均值和方差,而不用掩码区分三个子像素的信息。当然,只是这样做是不够好的——G子像素缺少了R子像素的信息,B子像素缺少了R、G子像素的信息。为了弥补信息的缺失,PixelCNN会为每个像素额外输出三个参数$\alpha, \beta, \gamma$,$\alpha$描述R对G子像素的约束关系,$\beta$描述R对B的约束关系,$\gamma$描述G对B的约束关系。

让我们来用公式更清晰地描述这一过程。对于某个像素的第$i$个logistic分布,网络会输出10个参数:$\pi, \mu_r, \mu_g, \mu_b, s_r, s_g, s_b, \alpha, \beta, \gamma$。$\pi$就是之前见过的选择第$i$个分布的概率,$\mu_r, \mu_g, \mu_b$是网络输出的三个子像素的均值,$s_r, s_g, s_b$是网络输出的三个子像素的标准差,$\alpha, \beta, \gamma$描述子像素之间的约束。

由于缺少了其他子像素的信息,网络直接输出的$\mu_g, \mu_b$是不准的。我们假设子像素之间仅存在简单的线性关系。这样,可以用下面的公式更新$\mu_g$和$\mu_b$:

更新后的$\mu_g$和$\mu_b$才是训练和采样时使用的最终均值。

你会不会疑惑上面那个公式里的$r$和$g$是哪里来的?别忘了,虽然子像素之间的约束被简化了,但是三个子像素还是按先后顺序生成的。在训练时,我们是知道所有子像素的真值的,公式里的$r$和$g$来自真值;而在采样时,我们会先用神经网络生成三个子像素的均值和方差,再采样$r$,把采样的$r$套入公式采样出$g$,最后把采样的$r,g$套入公式采样出$b$.

使用U-Net架构

PixelCNN++的网络架构是一个三级U-Net,即网络先下采样两次再上采样两次,同级编码器(下采样部分)的输出会连到解码器(上采样部分)的输入上。这个U-Net和其他任务中的U-Net没什么太大的区别。

使用Dropout

过拟合会导致生成图像的观感不好。为此,PixelCNN++采用了Dropout正则化方法,在每个子模块的第一个卷积后面加了一个Dropout。

除了这些改动外,PixelCNN++还使用了类似于Gated PixelCNN里垂直卷积和水平卷积的设计,以消除原PixelCNN里的视野盲区。当然,这点不算做本文的主要贡献。

总结

PixelCNN把文本生成的想法套入了图像生成中,令子像素的生成有一个先后顺序。为了在维护先后顺序的同时执行并行训练,PixelCNN使用了掩码卷积。这种并行训练与掩码的设计和Transformer一模一样。如果你理解了Transformer,就能一眼看懂PixelCNN的原理。

相比与其他的图像生成模型,以PixelCNN为代表的自回归模型在生成效果上并不优秀。但是,PixelCNN有两个特点:能准确计算某图像在模型里的出现概率(准确来说在统计学里叫做「似然」)、能生成离散的颜色输出。这些特性为后续诸多工作铺平了道路。

原版的PixelCNN有很多缺陷,后续很多工作对其进行了改进。Gated PixelCNN主要消除了原PixelCNN里的视野盲区,而PixelCNN++提出了一种泛用性很强的用连续分布建模离散颜色值的方法,并用简单的线性约束代替了原先较为复杂的用神经网络表示的子像素之间的约束。

PixelCNN相关的知识难度不高,了解本文介绍的内容足矣。PixelCNN也不是很常见的架构,复现代码的优先级不高,有时间的话阅读一下本文附录中的代码即可。另外,PixelCNN的代码实现里有一个重要的知识点。这个知识点几乎不会在论文和网上的文章里看到,但它对实现是否成功有着重要的影响。如果你对新知识感兴趣,推荐去读一下附录中对其的介绍。

参考资料与学习提示

Pixel Recurrent Neural Networks 是提出了PixelCNN的文章。当然,这篇文章主要是在讲PixelRNN,只想学PixelCNN的话通读这篇文章的价值不大。

Conditional Image Generation with PixelCNN Decoders 是提出Gated PixelCNN的文章。可以主要阅读消除视野盲区和门激活函数的部分。

PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications 是提出PixelCNN++的文章。整篇文章非常简练,可以整体阅读一遍,并且着重阅读离散logistic混合似然的部分。不过,这篇文章有很多地方写得过于简单了,连公式里的字母都不好好交代清楚,我还是看代码才看懂他们想讲什么。建议搭配本文的讲解阅读。

这几篇文章都使用了NLL(负对数似然)这个评价指标。实际上,这个指标就是对所有数据在模型里的平均出现概率取了个对数,加了个负号。对于PixelCNN,其NLL就是交叉熵损失函数。其他生成模型不是直接对数据的概率分布建模,它们的NLL不好求得。比如diffusion模型只能计算NLL的一个上界。

网上还有几份PyTorch代码复现供参考:

PixelCNN:https://github.com/singh-hrituraj/PixelCNN-Pytorch

Gated PixelCNN:https://github.com/anordertoreclaim/PixelCNN

附录:代码学习

在附录中,我将给出PixelCNN和Gated PixelCNN的PyTorch实现,并讲解PixelCNN++开源代码的实现细节。

PixelCNN 与 GatedPixelCNN

为了简化实现,我们来实现MNIST上的PixelCNN和Gated PixelCNN。MNIST是单通道数据集,我们不用考虑颜色通道之间复杂的约束。代码仓库:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/pixelcnn。

我们先准备好数据集。PyTorch的torchvision提供了获取了MNIST的接口,我们只需要用下面的函数就可以生成MNIST的Dataset实例。参数中,root为数据集的下载路径,download为是否自动下载数据集。令download=True的话,第一次调用该函数时会自动下载数据集,而第二次之后就不用下载了,函数会读取存储在root里的数据。

1
mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True)

我们可以用下面的代码来下载MNIST并输出该数据集的一些信息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torchvision
from torchvision.transforms import ToTensor
def download_dataset():
mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True)
print('length of MNIST', len(mnist))
id = 4
img, label = mnist[id]
print(img)
print(label)

# On computer with monitor
# img.show()

img.save('work_dirs/tmp.jpg')
tensor = ToTensor()(img)
print(tensor.shape)
print(tensor.max())
print(tensor.min())

if __name__ == '__main__':
import os
os.makedirs('work_dirs', exist_ok=True)
download_dataset()

执行这段代码,输出大致为:

1
2
3
4
5
6
length of MNIST 60000
<PIL.Image.Image image mode=L size=28x28 at 0x7FB3F09CCE50>
9
torch.Size([1, 28, 28])
tensor(1.)
tensor(0.)

第一行输出表明,MNIST数据集里有60000张图片。而从第二行和第三行输出中,我们发现每一项数据由图片和标签组成,图片是大小为28x28的PIL格式的图片,标签表明该图片是哪个数字。我们可以用torchvision里的ToTensor()把PIL图片转成PyTorch张量,进一步查看图片的信息。最后三行输出表明,每一张图片都是单通道图片(灰度图),颜色值的取值范围是0~1。

我们可以查看一下每张图片的样子。如果你是在用带显示器的电脑,可以去掉img.show那一行的注释,直接查看图片;如果你是在用服务器,可以去img.save的路径里查看图片。该图片的应该长这个样子:

我们可以用下面的代码预处理数据并创建DataLoader。PixelCNN对输入图片的颜色取值没有特别的要求,我们可以不对图片的颜色取值做处理,保持取值范围在0~1即可。

1
2
3
4
5
6
from torch.utils.data import DataLoader

def get_dataloader(batch_size: int):
dataset = torchvision.datasets.MNIST(root='./data/mnist',
transform=ToTensor())
return DataLoader(dataset, batch_size=batch_size, shuffle=True)

准备好数据后,我们来实现PixelCNN和Gated PixelCNN。先从PixelCNN开始。

实现PixelCNN,最重要的是实现掩码卷积。其代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskConv2d(nn.Module):

def __init__(self, conv_type, *args, **kwags):
super().__init__()
assert conv_type in ('A', 'B')
self.conv = nn.Conv2d(*args, **kwags)
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[0:H // 2] = 1
mask[H // 2, 0:W // 2] = 1
if conv_type == 'B':
mask[H // 2, W // 2] = 1
mask = mask.reshape((1, 1, H, W))
self.register_buffer('mask', mask, False)

def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res

掩码卷积的实现思路就是在卷积核组上设置一个mask。在前向传播的时候,先让卷积核组乘mask,再做普通的卷积。因此,掩码卷积类里需要实现一个普通卷积的操作。实现普通卷积,既可以写成继承nn.Conv2d,也可以把nn.Conv2d的实例当成成员变量。这份代码使用了后一种实现方法。在__init__里把其他参数原封不动地传给self.conv,并在forward中直接调用self.conv(x)

1
2
3
4
5
6
7
8
9
10
11
12
class MaskConv2d(nn.Module):

def __init__(self, conv_type, *args, **kwags):
super().__init__()
...
self.conv = nn.Conv2d(*args, **kwags)
...

def forward(self, x):
...
conv_res = self.conv(x)
return conv_res

准备好卷积对象后,我们来维护掩码张量。由于输入输出都是单通道图像,按照正文中关于PixelCNN的描述,我们只需要在卷积核的h, w两个维度设置掩码。我们可以用下面的代码生成一个形状为(H, W)的掩码并根据卷积类型对掩码赋值:

1
2
3
4
5
6
7
8
9
10
def __init__(self, conv_type, *args, **kwags):
super().__init__()
assert conv_type in ('A', 'B')
...
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[0:H // 2] = 1
mask[H // 2, 0:W // 2] = 1
if conv_type == 'B':
mask[H // 2, W // 2] = 1

然后,为了保证掩码能正确广播到4维的卷积核组上,我们做一个reshape操作。

1
mask = mask.reshape((1, 1, H, W))

在初始化函数的最后,我们把用PyTorch API把mask注册成名为'mask'的成员变量。register_buffer可以把一个变量加入成员变量的同时,记录到PyTorch的Module中。这样做的好处时,每当执行model.to(device)把模型中所有参数转到某个设备上时,被注册的变量会跟着转。否则的话我们要手动model.mask = model.mask.to(device)转设备。register_buffer的第三个参数表示被注册的变量是否要加入state_dict中以保存下来。由于这里mask每次都会自动生成,我们不需要把它存下来,可以令第三个参数为False

1
self.register_buffer('mask', mask, False)

在前向传播时,只需要先让卷积核乘掩码,再做普通的卷积。

1
2
3
4
def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res

有了最核心的掩码卷积,我们来根据论文中的模型结构图把模型搭起来。

我们先照着论文实现残差块ResidualBlock。原论文并没有使用归一化,但我发现使用归一化后效果会好一点,于是往模块里加了BatchNorm。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class ResidualBlock(nn.Module):

def __init__(self, h, bn=True):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(2 * h, h, 1)
self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()
self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)
self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()
self.conv3 = nn.Conv2d(h, 2 * h, 1)
self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()

def forward(self, x):
y = self.relu(x)
y = self.conv1(y)
y = self.bn1(y)
y = self.relu(y)
y = self.conv2(y)
y = self.bn2(y)
y = self.relu(y)
y = self.conv3(y)
y = self.bn3(y)
y = y + x
return y

有了所有这些基础模块后,我们就可以拼出最终的PixelCNN了。注意,我们可以自己决定颜色有几个亮度级别。要修改亮度级别的数量,只需要修改softmax输出的通道数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class PixelCNN(nn.Module):

def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):
super().__init__()
self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)
self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()
self.residual_blocks = nn.ModuleList()
for _ in range(n_blocks):
self.residual_blocks.append(ResidualBlock(h, bn))
self.relu = nn.ReLU()
self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)
self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
self.out = nn.Conv2d(linear_dim, color_level, 1)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
for block in self.residual_blocks:
x = block(x)
x = self.relu(x)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
x = self.out(x)
return x

PixelCNN实现完毕,我们来按照同样的流程实现Gated PixelCNN。首先,我们要实现其中的垂直掩码卷积和水平掩码卷积,二者的实现和PixelCNN里的掩码卷积差不多,只是mask的内容不太一样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class VerticalMaskConv2d(nn.Module):

def __init__(self, *args, **kwags):
super().__init__()
self.conv = nn.Conv2d(*args, **kwags)
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[0:H // 2 + 1] = 1
mask = mask.reshape((1, 1, H, W))
self.register_buffer('mask', mask, False)

def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res


class HorizontalMaskConv2d(nn.Module):

def __init__(self, conv_type, *args, **kwags):
super().__init__()
assert conv_type in ('A', 'B')
self.conv = nn.Conv2d(*args, **kwags)
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[H // 2, 0:W // 2] = 1
if conv_type == 'B':
mask[H // 2, W // 2] = 1
mask = mask.reshape((1, 1, H, W))
self.register_buffer('mask', mask, False)

def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res

水平卷积其实只要用一个1x3的卷积就可以实现了。但出于偷懒(也为了方便理解),我还是在3x3卷积的基础上添加的mask

之后我们来用两种卷积搭建论文中的Gated Block。

Gated Block搭起来稍有难度。如上面的结构图所示,我们主要要维护两个v, h两个变量,分别表示垂直卷积部分的结果和水平卷积部分的结果。v会经过一个垂直掩码卷积和一个门激活函数。h会经过一个类似于残差块的结构,只不过第一个卷积是水平掩码卷积、激活函数是门激活函数、进入激活函数之前会和垂直卷积的信息融合。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class GatedBlock(nn.Module):

def __init__(self, conv_type, in_channels, p, bn=True):
super().__init__()
self.conv_type = conv_type
self.p = p
self.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, 1)
self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,
1)
self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
self.h_output_conv = nn.Conv2d(p, p, 1)
self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()

def forward(self, v_input, h_input):
v = self.v_conv(v_input)
v = self.bn1(v)
v_to_h = v[:, :, 0:-1]
v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
v_to_h = self.v_to_h_conv(v_to_h)
v_to_h = self.bn2(v_to_h)

v1, v2 = v[:, :self.p], v[:, self.p:]
v1 = torch.tanh(v1)
v2 = torch.sigmoid(v2)
v = v1 * v2

h = self.h_conv(h_input)
h = self.bn3(h)
h = h + v_to_h
h1, h2 = h[:, :self.p], h[:, self.p:]
h1 = torch.tanh(h1)
h2 = torch.sigmoid(h2)
h = h1 * h2
h = self.h_output_conv(h)
h = self.bn4(h)
if self.conv_type == 'B':
h = h + h_input
return v, h

代码中的其他地方都比较常规,唯一要注意的是vh的合成部分。这一部分的实现初看下来比较难懂。为了把v的信息贴到h上,我们并不是像前面的示意图所写的令v上移一个单位,而是用下面的代码令v下移了一个单位(下移即去掉最下面一行,往最上面一行填0)。

1
2
v_to_h = v[:, :, 0:-1]
v_to_h = F.pad(v_to_h, (0, 0, 1, 0))

为什么实际上是要对特征图v下移一个单位呢?实际上,在拼接vh时,我们是想做下面这个计算:

1
2
3
for i in range(H):
for j in range(W):
h[:, :, i, j] += v[:, :, i - 1, j]

但是,写成循环就太慢了,我们最好是能做向量化计算。注意到,vi相加的位置只差了一个单位。为了把相加的位置对齐,我们要把v往下移一个单位,把原来在i-1处的信息移到i上。这样,移动过后的v_to_h就能和h直接用向量加法并行地加到一起了。

除了vh的合成有点麻烦外,GatedBlock还有一个细节值得注意。h的计算通路中有一个残差连接,但是,在网络的第一层,每个数据是不能看到自己的。所以,当GatedBlock发现卷积类型为A类时,不应该对h做残差连接。

最后,我们来用GatedBlock搭出Gated PixelCNN。Gated PixelCNN和原版PixelCNN的结构非常相似,只是把ResidualBlock替换成了GatedBlock而已。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class GatedPixelCNN(nn.Module):

def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
super().__init__()
self.block1 = GatedBlock('A', 1, p, bn)
self.blocks = nn.ModuleList()
for _ in range(n_blocks):
self.blocks.append(GatedBlock('B', p, p, bn))
self.relu = nn.ReLU()
self.linear1 = nn.Conv2d(p, linear_dim, 1)
self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
self.out = nn.Conv2d(linear_dim, color_level, 1)

def forward(self, x):
v, h = self.block1(x, x)
for block in self.blocks:
v, h = block(v, h)
x = self.relu(h)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
x = self.out(x)
return x

准备好了模型代码,我们可以编写训练和采样的脚本了。我们先用超参数初始化好两个模型。根据论文的描述,PixelCNN有15个残差块,中间特征的通道数为128,输出前线性层的通道数为32。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from dldemos.pixelcnn.dataset import get_dataloader, get_img_shape
from dldemos.pixelcnn.model import PixelCNN, GatedPixelCNN

import torch
import torch.nn as nn
import torch.nn.functional as F

import time
import einops
import cv2

import numpy as np
import os

batch_size = 128
color_level = 8 # or 256

models = [
PixelCNN(15, 128, 32, True, color_level),
GatedPixelCNN(15, 128, 32, True, color_level)
]

if __name__ == '__main__':
os.makedirs('work_dirs', exist_ok=True)
model_id = 1
model = models[model_id]
device = 'cuda'
model_path = f'dldemos/pixelcnn/model_{model_id}_{color_level}.pth'
train(model, device, model_path)
sample(model, device, model_path,
f'work_dirs/pixelcnn_{model_id}_{color_level}.jpg')

之后是训练部分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def train(model, device, model_path):
dataloader = get_dataloader(batch_size)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
loss_fn = nn.CrossEntropyLoss()
n_epochs = 40
tic = time.time()
for e in range(n_epochs):
total_loss = 0
for x, _ in dataloader:
current_batch_size = x.shape[0]
x = x.to(device)
y = torch.ceil(x * (color_level - 1)).long()
y = y.squeeze(1)
predict_y = model(x)
loss = loss_fn(predict_y, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * current_batch_size
total_loss /= len(dataloader.dataset)
toc = time.time()
torch.save(model.state_dict(), model_path)
print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')
print('Done')

这部分代码十分常规,和普通的多分类任务十分类似。代码中值得一看的是下面几行:

1
2
3
4
y = torch.ceil(x * (color_level - 1)).long()
y = y.squeeze(1)
predict_y = model(x)
loss = loss_fn(predict_y, y)

这几行代码根据输入x得到了标签y,再做前向传播,最后用预测的predict_yy求交叉熵损失函数。这里第一个要注意的地方是y = y.squeeze(1)这一行。在PyTorch中用交叉熵函数时,标签的形状应该为[N, A, B, ...],预测值的形状应为[N, num_class, A, B, ...]。其中,A,B, ...表示数据的形状。在我们的任务中,数据是二维的,因此标签的形状应为[N, H, W],预测值的形状应为[N, num_class, H, W]。而我们在DataLoader中获得的数据的形状是[N, 1, H, W]。我们要对数据y的形状做一个变换,使之满足PyTorch的要求。这里由于输入是单通道,我们可以随便用squeeze()y长度为1的通道去掉。如果图像是多通道的话,我们则不应该修改y,而是要对预测张量y_predict做一个reshape,改成[N, num_class, C, H, W]

第二个要注意的是y = torch.ceil(x * (color_level - 1)).long()这一行。为什么需要写一个这么复杂的浮点数转整数呢?这个地方的实现需要多解释几句。在我们的代码中,PixelCNN的输入可能来自两个地方:

  1. 训练时,PixelCNN的输入来自数据集。数据集里的颜色值是0~1的浮点数。
  2. 采样时,PixelCNN的输入来自PixelCNN的输出。PixelCNN的输出是整型(别忘了,PixelCNN只能产生离散的输出)。

两种输入,一个是0~1的浮点数,一个是0~color_level-1的整数。为了统一两个输入的形式,最简单的做法是对整型颜色输入做个除法,映射到0~1里,把它统一到浮点数上。

此外,还有一个地方需要类型转换。在训练时,我们需要得到每个像素的标签,即得到每个像素颜色的真值。由于PixelCNN的输出是离散的,这个标签也得是一个离散的颜色。而标签来自训练数据,训练数据又是0~1的浮点数。因此,在计算标签时,需要做一次浮点到整型的转换。这样,整个项目里就有两个重要的类型转换:一个是在获取标签时把浮点转整型,一个是在采样时把整型转浮点。这两个类型转换应该恰好「互逆」,不然就会出现转过去转不回来的问题。

在项目中,我使用了下图所示的浮点数映射到整数的方法。0.0映射到0,(0, 1/255]映射到1,……(254/255, 1]映射到255。即浮点转整型时使用ceil(x*255),整型转浮点的时候使用x/255。这种简单的转换方法保证一个区间里的离散颜色值只会映射到一个整数上,同时把整数映射回浮点数时该浮点数也会落在区间里。如果你随手把浮点转整型写成了int(x*255),则会出现浮点转整数和整数转浮点对应不上的问题,到时候采样的结果会很不好。

由于一个整型只能映射到一个浮点数,而多个浮点数会映射到一个整数,严格来说,大部分浮点数转成整数再转回来是变不回原来的浮点数的。这两个转换过程从数学上来说不是严格的互逆。但是,如果我们马虎一点,把位于同一个区间的浮点数看成等价的,那么浮点数和整数之间的映射就是一个双射,来回转换不会有任何信息损失。

刚才代码中y = torch.ceil(x * (color_level - 1)).long()这一行实际上就是在描述怎样把训练集的浮点颜色值转换成0~color_level-1之间的整型标签的。

再来看看采样部分的代码。和正文里的描述一样,在采样时,我们把x初始化成一个0张量。之后,循环遍历每一个像素,输入x,把预测出的下一个像素填入x.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def sample(model, device, model_path, output_path, n_sample=81):

model.eval()
model.load_state_dict(torch.load(model_path))
model = model.to(device)
C, H, W = get_img_shape() # (1, 28, 28)
x = torch.zeros((n_sample, C, H, W)).to(device)
with torch.no_grad():
for i in range(H):
for j in range(W):
output = model(x)
prob_dist = F.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist,
1).float() / (color_level - 1)
x[:, :, i, j] = pixel

imgs = x * 255
imgs = imgs.clamp(0, 255)
imgs = einops.rearrange(imgs,
'(b1 b2) c h w -> (b1 h) (b2 w) c',
b1=int(n_sample**0.5))

imgs = imgs.detach().cpu().numpy().astype(np.uint8)

cv2.imwrite(output_path, imgs)

整个采样代码的核心部分是下面这几行。我们先获取模型的输出,再用softmax转换成概率分布,再用torch.multinomial(prob_dist, 1)从概率分布里采样出一个0~(color_level-1)的离散颜色值,再除以(color_level - 1)把离散颜色转换成浮点颜色(因为网络是输入是浮点颜色),最后把新像素填入生成图像。

1
2
3
4
5
output = model(x)
prob_dist = F.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist,
1).float() / (color_level - 1)
x[:, :, i, j] = pixel

上面的代码中,如前所述,/ (color_level - 1)与前面的torch.ceil(x * (color_level - 1)).long()必须是对应起来的。两个操作必须「互逆」,不然就会出问题。

当然,最后得到的图像x是一个用0~1浮点数表示的图像,可以直接把它乘255变成一个用8位字节表示的图像,这一步浮点到整型的转换是为了让图像输出,和其他图像任务的后处理是一样的,和PixelCNN对于离散颜色和连续颜色的建模不是同一个意思,不是非得取一次ceil()

PixelCNN训练起来很慢。在代码中,我默认训练40个epoch。原版PixelCNN要花一小时左右训完,Gated PixelCNN就更慢了。

以下是我得到的一些采样结果。首先是只有8个颜色级别的PixelCNN和Gated PixelCNN。

可以看出,PixelCNN经常会生成一些没有意义的「数字」,而Gated PixelCNN生成的大部分数字都是正常的。但由于颜色级别只有8,模型偶尔会生成较粗的色块。这个在Gated PixelCNN的输出里比较明显。

之后看一下正常的256个颜色级别的PixelCNN和Gated PixelCNN采样结果。

由于颜色级别增大,任务难度变大,这两个模型的生成效果就不是那么好了。当然,Gated PixelCNN还是略好一些。训练效果差,与MNIST的特性(大部分像素都是0和255)以及PixelCNN对于离散颜色的建模有关。PixelCNN的这一缺陷已经在PixelCNN++论文里分析过了。

PixelCNN++ 源码阅读

PixelCNN++在实现上细节颇多,复现起来难度较大。而且它的官方实现是拿TensorFlow写的,对于只会PyTorch的选手来说不够友好。还好,PixelCNN++的官方实现非常简练,核心代码只有两个文件,没有过度封装,也没有过度使用API,哪怕不懂TensorFlow也不会有障碍(但由于代码中有很多科学计算,阅读起来没有障碍,却难度不小)。让我们来通过阅读官方源码来学习PixelCNN++的实现。

官方代码的地址在 https://github.com/openai/pixel-cnn 。源码有两个核心文件:nn.py实现了网络模块及一些重要的训练和采样函数,model.py定义了网络的结构。让我们自顶向下地学习,先看model.py,看到函数调用后再跑到nn.py里查看实现细节。

model.py里就只有一个函数model_spec,它定义了神经网络的结构。
它的参数为:

1
2
3
4
5
6
7
8
9
10
def model_spec(x, 
h=None,
init=False,
ema=None,
dropout_p=0.5,
nr_resnet=5,
nr_filters=160,
nr_logistic_mix=10,
resnet_nonlinearity='concat_elu',
energy_distance=False):

各参数的意义为:

  • x: 形状为[N, H, W, D1]的输入张量。其中,D1表示输入通道数。对于RGB图像,D1=3
  • h: 形状为[N, K]的约束条件,即对于每个batch来说,约束条件是一个长度K的向量。这里的约束条件和Gatd PixelCNN中提出的一样,可以是文字,也可以是类别,只要约束条件最终被转换成一个向量就行。
  • init: 是否执行初始化。这和TensorFlow的实现有关,可以不管。
  • ema: 对参数使用指数移动平均,一种训练优化技巧,和论文无关,可以不管。
  • dropout_p: dropout的概率。
  • nr_resnet: U-Net每一块里有几个ResNet层(U-Net一共有6块,编码器3块解码器3块)。
  • nr_filters: 每个卷积层的卷积核个数,即所有中间特征图的通道数。
  • nr_logistic_mix: 论文里的$K$,表示用几个logistic分布混合起来描述一个颜色分布。
  • resnet_nonlinearity: 激活函数的类别。
  • energy_distance:是否使用论文里没提过的一种算损失函数的办法,可以不管。

之后来看函数体。20行with arg_scope ([nn.conv2d, ...], counters=counters, ...)大概是说进入了TensorFlow里的arg_scope这个上下文。只要在上下文里,后面counters等参数就会被自动传入nn.conv2d等函数,而不需要在函数里显式传参。这样写会让后面的函数调用更简短一点。

22行至30行在选择激活函数,可以直接跳过。

1
2
3
4
5
6
7
8
9
# parse resnet nonlinearity argument
if resnet_nonlinearity == 'concat_elu':
resnet_nonlinearity = nn.concat_elu
elif resnet_nonlinearity == 'elu':
resnet_nonlinearity = tf.nn.elu
elif resnet_nonlinearity == 'relu':
resnet_nonlinearity = tf.nn.relu
else:
raise('resnet nonlinearity ' + resnet_nonlinearity + ' is not supported')

从35行开始,函数正式开始定义网络结构。一开始,代码里有一个匪夷所思的操作:先是取出输入张量的形状xs,再根据这个形状给x填充了一个全是1的通道。

1
2
xs = nn.int_shape(x)
x_pad = tf.concat([x,tf.ones(xs[:-1]+[1])],3) # add channel of ones to distinguish image from padding later on

虽然作者加了注释,说这个x_pad后面会用到。但我翻遍了代码,楞是没找到这个多出来的通道发挥了什么作用。GitHub issue里也有人提问,问这个x_pad在做什么。有其他用户给了回复,说他尝试了去掉填充,结果不变。可见这一行代码确实是毫无贡献,还增加了不必要的计算量。大概是作者没删干净过时的实现代码。

之后的几行是在初始化上卷积和左上卷积的中间结果(上卷积和Gated PixelCNN里的垂直卷积类似,左上卷积和Gated PixelCNN里的水平卷积类似)。u_list会保存所有上卷积在编码器里的结果,ul_list会保存所有左上卷积在编码器里的结果。这些结果会供解码器使用。

1
2
3
4
5
6
7
8
9
10
11
12
 u_list = [nn.down_shift(
nn.down_shifted_conv2d(x_pad,
num_filters=nr_filters,
filter_size=[2, 3])
)] # stream for pixels above
ul_list = [nn.down_shift(
nn.down_shifted_conv2d(x_pad,
num_filters=nr_filters,
filter_size=[1,3])
) + nn.right_shift(
nn.down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1])
)] # stream for up and to the left

作者没有使用带掩码的卷积,而是通过普通卷积加偏移等效实现了掩码卷积。这一实现非常巧妙,效率更高。我们来看看这几个卷积的实现方法。

首先看上卷积down_shifted_conv2d,它表示实现一个卷积中心在卷积核正下方的卷积。作者使用了[2,3]的卷积核,并手动给卷积填充(注意,卷积的类型是'valid'不是'same')。这种卷积等价于我们做普通的3x3卷积再给上面6个像素打上掩码。

1
2
3
def down_shifted_conv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs):
x = tf.pad(x, [[0,0],[filter_size[0]-1,0], [int((filter_size[1]-1)/2),int((filter_size[1]-1)/2)],[0,0]])
return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs)

作者在down_shifted_conv2d之后跟了一个down_shift。这个操作和我们实现Gated PixelCNN时移动v_to_h张量的做法一样,去掉张量最下面一行,在最上面一行填0,也就是让张量往下移了一格。

1
2
3
def down_shift(x):
xs = int_shape(x)
return tf.concat([tf.zeros([xs[0],1,xs[2],xs[3]]), x[:,:xs[1]-1,:,:]],1)

类似地,在做第一次左上卷积时,作者把一个下移过的1x3卷积结果和一个右移过的2x1卷积结果拼到了一起。其中,down_right_shifted_conv2d就是实现一个卷积中心在卷积核右下角的卷积。

1
2
3
4
5
6
7
ul_list = [nn.down_shift(
nn.down_shifted_conv2d(x_pad,
num_filters=nr_filters,
filter_size=[1,3])
) + nn.right_shift(
nn.down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1])
)]

初始化完毕后,数据就正式进入了U-Net。让我们先略过函数的细节,看一看模型的整体架构。在下采样部分,三级U-Net在每一级都是先经过若干个gated_resnet模块,再下采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
for rep in range(nr_resnet):
u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d))

u_list.append(nn.down_shifted_conv2d(u_list[-1], num_filters=nr_filters, stride=[2, 2]))
ul_list.append(nn.down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, stride=[2, 2]))

for rep in range(nr_resnet):
u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d))

u_list.append(nn.down_shifted_conv2d(u_list[-1], num_filters=nr_filters, stride=[2, 2]))
ul_list.append(nn.down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, stride=[2, 2]))

for rep in range(nr_resnet):
u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d))

之后是上采样。类似地,数据先经过若干个gated_resnet模块,再上采样。与前半部分不同的是,前半部分的输出会从u_listul_list中逐个取出(实际上这两个list起到了一个栈的作用),接入到gated_resnet的输入里。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
u = u_list.pop()
ul = ul_list.pop()
for rep in range(nr_resnet):
u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d)
tf.add_to_collection('checkpoints', u)
tf.add_to_collection('checkpoints', ul)

u = nn.down_shifted_deconv2d(u, num_filters=nr_filters, stride=[2, 2])
ul = nn.down_right_shifted_deconv2d(ul, num_filters=nr_filters, stride=[2, 2])

for rep in range(nr_resnet+1):
u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d)
tf.add_to_collection('checkpoints', u)
tf.add_to_collection('checkpoints', ul)

u = nn.down_shifted_deconv2d(u, num_filters=nr_filters, stride=[2, 2])
ul = nn.down_right_shifted_deconv2d(ul, num_filters=nr_filters, stride=[2, 2])

for rep in range(nr_resnet+1):
u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d)
tf.add_to_collection('checkpoints', u)
tf.add_to_collection('checkpoints', ul)

模型U-Net的部分到此为止。整个网络的结构并不复杂,我们只要看懂了nn.gated_resnet的实现,就算理解了整个模型的实现。让我们来详细看一下这个模块是怎么实现的。以下是整个模块的实现代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs):
xs = int_shape(x)
num_filters = xs[-1]

c1 = conv(nonlinearity(x), num_filters)
if a is not None: # add short-cut connection if auxiliary input 'a' is given
c1 += nin(nonlinearity(a), num_filters)
c1 = nonlinearity(c1)
if dropout_p > 0:
c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)
c2 = conv(c1, num_filters * 2, init_scale=0.1)

# add projection of h vector if included: conditional generation
if h is not None:
with tf.variable_scope(get_name('conditional_weights', counters)):
hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,
initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
if init:
hw = hw.initialized_value()
c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])

a, b = tf.split(c2, 2, 3)
c3 = a * tf.nn.sigmoid(b)
return x + c3

照例,我们来先看一下函数的每个参数的意义。

1
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs)
  • x: 模块的输入。
  • a: 模块的附加输入。附加输入有两个来源:上方u_list的信息传递给左上方ul_list的信息、编码器把信息传递给解码器。
  • h: 形状为[N, K]的约束条件。从模型的参数里传递而来。
  • nonlinearity: 激活函数。从模型的参数里传递而来。
  • conv:卷积操作的函数。可能是上卷积或者左上卷积。
  • init: 是否执行初始化。这和TensorFlow的实现有关,可以不管。
  • counters: 作者写的一个用于方便地给模块的命名的字典,可以不管。
  • ema: 对参数使用指数移动平均。从模型的参数里传递而来。
  • dropout_p: dropout的概率。从模型的参数里传递而来。

模块主要是做了下面这些卷积操作。一开始,先对输入x做卷积,得到c1。如果有额外输入a,则对a做一个1x1卷积(作者自己实现了1x1卷积,把函数命名为nin),加到c1上。做完第一个卷积后,过一个dropout层。最后再卷积一次,得到2*num_filters通道数的张量。

1
2
3
4
5
6
7
c1 = conv(nonlinearity(x), num_filters)
if a is not None: # add short-cut connection if auxiliary input 'a' is given
c1 += nin(nonlinearity(a), num_filters)
c1 = nonlinearity(c1)
if dropout_p > 0:
c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)
c2 = conv(c1, num_filters * 2, init_scale=0.1)

之后,作者也使用了一种门结构作为整个模块的激活函数。但是和Gated PixelCNN相比,PixelCNN++的门结构简单一点。详见下面的代码。

1
2
a, b = tf.split(c2, 2, 3)
c3 = a * tf.nn.sigmoid(b)

最后输出时,c3和输入x之间有一个残差连接。

1
return x + c3

看完gated_resnet的实现,我们可以跳回去继续看模型结构了。经过了U-Net的主体结构后,只需要经过一个输出层就可以得到最终的输出了。输出层里,作者用1x1卷积修改了输出通道数,令最后的通道数为10*nr_logistic_mix

1
2
3
4
5
6
7
8
9
if energy_distance:
# 跳过
else:
x_out = nn.nin(tf.nn.elu(ul),10*nr_logistic_mix)

assert len(u_list) == 0
assert len(ul_list) == 0

return x_out

大家还记得这个10是从哪里来的吗?在正文中,我们曾经学过,对于某个像素的第$i$个logistic分布,网络会输出10个参数:$\pi, \mu_r, \mu_g, \mu_b, s_r, s_g, s_b, \alpha, \beta, \gamma$。这个10就是10个参数的意思。

光知道一共有10个参数还不够。接下来就是PixelCNN++比较难懂的部分——怎么用这些参数构成一共logistic分布,并从连续分布中得到离散的概率分布。这些逻辑被作者写在了损失函数nn.discretized_mix_logistic_loss里面。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def discretized_mix_logistic_loss(x,l,sum_all=True):
""" log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """
xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100)
nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics
logit_probs = l[:,:,:,:nr_mix]
l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3])
means = l[:,:,:,:,:nr_mix]
log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.)
coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])
x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels
m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix])
m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix])
means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3],3)
centered_x = x - means
inv_stdv = tf.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1./255.)
cdf_plus = tf.nn.sigmoid(plus_in)
min_in = inv_stdv * (centered_x - 1./255.)
cdf_min = tf.nn.sigmoid(min_in)
log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling)
log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling)
cdf_delta = cdf_plus - cdf_min # probability for all other cases
mid_in = inv_stdv * centered_x
log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code)

log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.where(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5))))

log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs)
if sum_all:
return -tf.reduce_sum(log_sum_exp(log_probs))
else:
return -tf.reduce_sum(log_sum_exp(log_probs),[1,2])

这个函数很长,很难读。它实际上可以被拆成四个部分:取参数、求均值、求离散概率、求和。让我们一部分一部分看过来。

首先是取参数部分,这部分代码如下所示。模型一共输出了10*nr_mix个参数,即输出了nr_mix组参数,每组有10个参数。如前所述,第一个参数是选择该分布的未经过softmax的概率logit_probs,之后的6个参数是三个通道的均值及三个通道的标准差取log,最后3个参数是描述通道间依赖关系的$\alpha, \beta, \gamma$。不用去认真阅读这段代码,只需要知道这些代码可以把数据取出来即可。

1
2
3
4
5
6
7
8
xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100)
nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics
logit_probs = l[:,:,:,:nr_mix]
l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3])
means = l[:,:,:,:,:nr_mix]
log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.)
coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])

之后是求均值部分。在第一行,作者用了一种曲折的方式实现了repeat操作,把x在最后一维重复了nr_mix次,方便后续处理。在第二第三行,作者根据论文里的公式,调整了G通道和B通道的均值。在最后第四行,作者把所有均值张量拼到了一起。

1
2
3
4
x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels
m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix])
m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix])
means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3],3)

再来是求离散概率部分。作者根据论文里的公式,算出了当前离散分布的积分上限和积分下限(通过从累计分布密度函数里取值),再做差,得到了离散分布的概率。由于最终的概率值要求log,作者没有按照公式的顺序先算累计分布概率函数的值,再取log,而是把所有计算放到一起并化简。这样代码虽然难读了一点,但减少了不必要的计算,也减少了精度损失。

1
2
3
4
5
6
7
8
9
 centered_x = x - means
inv_stdv = tf.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1./255.)
cdf_plus = tf.nn.sigmoid(plus_in)
min_in = inv_stdv * (centered_x - 1./255.)
cdf_min = tf.nn.sigmoid(min_in)
log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling)
log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling)
cdf_delta = cdf_plus - cdf_min # probability for all other cases

作者还算了积分区间中心的概率,以处理某些边界情况。实际上这个值没有在代码中使用。

1
2
3
mid_in = inv_stdv * centered_x
log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in)
# log probability in the center of the bin, to be used in extreme cases (not actually used in our code)

光做差还不够。为了处理颜色值在0和255的边界情况,作者还给代码加入了一些边界上的特判,才得到了最终的概率log_probs
1
2
3
4
5
log_probs = tf.where(x < -0.999, log_cdf_plus, 
tf.where(x > 0.999, log_one_minus_cdf_min,
tf.where(cdf_delta > 1e-5,
tf.log(tf.maximum(cdf_delta, 1e-12)),
log_pdf_mid - np.log(127.5))))

最后是loss求和部分。除了要把离散概率的对数求和外,还要加上选择这个分布的概率的对数。log_prob_from_logits就是做一个softmax再求一个log。算上了选择分布的概率后,再对loss求一次和,就得到了最终的loss。

1
2
3
4
5
log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs)
if sum_all:
return -tf.reduce_sum(log_sum_exp(log_probs))
else:
return -tf.reduce_sum(log_sum_exp(log_probs),[1,2])

至此,我们就看完了训练部分的关键代码。我们再来看一看采样部分最关键的代码,怎么从logisitc分布里采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def sample_from_discretized_mix_logistic(l,nr_mix):
ls = int_shape(l)
xs = ls[:-1] + [3]
# unpack parameters
logit_probs = l[:, :, :, :nr_mix]
l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3])
# sample mixture indicator from softmax
sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5))), 3), depth=nr_mix, dtype=tf.float32)
sel = tf.reshape(sel, xs[:-1] + [1,nr_mix])
# select logistic parameters
means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4)
log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.)
coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4)
# sample from logistic & clip to interval
# we don't actually round to the nearest 8bit value when sampling
u = tf.random_uniform(means.get_shape(), minval=1e-5, maxval=1. - 1e-5)
x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u))
x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.)
x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.)
x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.)
return tf.concat([tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])],3)

一开始,还是和刚刚的求loss一样,作者把参数从网络输出l里拆出来。logit_probs是选择某分布的未经softmax的概率,其余的参数是均值、标准差、通道间依赖参数。

1
2
3
4
5
6
def sample_from_discretized_mix_logistic(l,nr_mix):
ls = int_shape(l)
xs = ls[:-1] + [3]
# unpack parameters
logit_probs = l[:, :, :, :nr_mix]
l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3])

之后,作者对logit_probs做了一个softmax,得到选择各分布的概率。之后,作者根据这个概率分布采样,从nr_mix个logistic分布里选了一个做为这次生成使用的分布。作者没有使用下标来选择数据,而是把选中的序号编码成one-hot向量sel,通过乘one-hot向量来实现从某数据组里取数。

1
2
sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5))), 3), depth=nr_mix, dtype=tf.float32)
sel = tf.reshape(sel, xs[:-1] + [1,nr_mix])

接着,作者根据sel,取出nr_mix个logistic分布中某一个分布的均值、标准差、依赖系数。

1
2
3
4
# select logistic parameters
means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4)
log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.)
coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4)

再然后,作者用下面两行代码完成了从logistic分布的采样。从一个连续概率分布里采样是一个基础的数学问题。其做法是先求概率分布的累计分布函数。由于累计分布函数可以把自变量一一映射到0~1之间的概率,我们就得到了一个0~1之间的数到自变量的映射,即累积分布函数的反函数。通过对0~1均匀采样,再套入累积分布函数的反函数,就完成了采样。下面第二行计算其实就是在算logisitc分布的累积分布函数的反函数的一个值。

1
2
u = tf.random_uniform(means.get_shape(), minval=1e-5, maxval=1. - 1e-5)
x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u))

只从分布里采样还不够,我们还得算上依赖系数。把依赖系数的贡献算完后,整个采样就结束了,我们得到了RGB三个颜色值。

1
2
3
4
x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.)
x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.)
x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.)
return tf.concat([tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])],3)

至此,PixelCNN++中最具有学习价值的代码就看完了。让我再次总结一下PixelCNN++中的重要代码,并介绍一下学习它们需要什么前置知识。

PixelCNN++中第一个比较重要的地方是掩码卷积的实现。它没有真的使用到掩码,而是使用了卷积中心在卷积核下方和右下角的卷积来等价实现。要读懂这些代码,你需要先看懂PixelCNN和Gated PixelCNN里面对于掩码卷积的定义,知道PixelCNN++为什么要做两种卷积。之后,你还需要对卷积操作有一点基础的认识,知道卷积操作的填充方式其实是在改变卷积中心在卷积核中的位置。你不需要懂太多TensorFlow的知识,毕竟卷积的API就那么几个参数,每个框架都差不多。

PixelCNN++的另一个比较重要的地方是logistic分布的离散概率计算与采样。为了学懂这些,你需要一点比较基础的统计学知识,知道概率密度函数与累积分布函数的关系,知道怎么用计算机从一个连续分布里采样。之后,你要读懂PixelCNN++是怎么用logistic分布对离散概率建模的,知道logistic分布的累计分布函数就是sigmoid函数。懂了这些,你看代码就不会有太多问题,代码基本上就是对论文内容的翻译。反倒是如果读论文没读懂,可以去看代码里的实现细节。

最近,我需要在Python里使用PatchMatch算法(一种算两张图片逐像素匹配关系的算法)。我去网上搜了一份实现,跑了下测试程序,发现它跑边长300像素的图片都要花三分钟。这个速度实在太慢了。我想起以前瞟到过一篇介绍Python加速的文章,里面提到过Numba这个库。于是,我现学现用,最终成功用Numba让原来要跑180秒的程序在0.6秒左右跑完。可见,Numba学起来是很快的。在这篇文章中,我将以这个Python版PatchMatch项目为例,介绍如何快速从零上手Numba,以大幅加速Python科学计算程序。这篇文章不会涉及PatchMatch算法的原理,只要你写过Python,就能读懂本文。

缘起:一份缓慢的 PatchMatch 实现

PatchMatch是Adobe提出的一种快速计算两张图片逐像素匹配关系的算法。也就是说,输入两张类似的图片A和B(比如视频里的连续两帧),算法能输出图片A中的每个像素对应B中的哪个像素(可能会出现多对一的情况)。为了快速验证算法的效果,我们可以输入图片A和B,用算法获取A到B的匹配关系,再根据匹配关系从B中取像素重建A。如果重建出来的图片和原来的图片A看上去差不多,那算法的效果就很不错。下图是一份PatchMatch测试程序的输出。

我在GitHub上找到了一份简明实用的Python版PatchMatch实现,得到了上面的输出结果。结果是挺不错,但哪怕是跑326x244这么小的图片,都要花约180秒才能跑完。

我懒得从头学一遍PatchMatch,决定直接上手优化代码。代码不长,其函数调用关系能很快理清。

代码入口函数是NNS()。它先是调用了initialization(),再循环itr次,每次遍历所有像素,对每个像素调用propagation()random_search()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def NNS(img, ref, p_size, itr):
A_h = np.size(img, 0)
A_w = np.size(img, 1)
f, dist, img_padding = initialization(img, ref, p_size)
for itr in range(1, itr + 1):
if itr % 2 == 0:
for i in range(A_h - 1, -1, -1):
for j in range(A_w - 1, -1, -1):
a = np.array([i, j])
propagation(f, a, dist, img_padding, ref, p_size, False)
random_search(f, a, dist, img_padding, ref, p_size)
else:
for i in range(A_h):
for j in range(A_w):
a = np.array([i, j])
propagation(f, a, dist, img_padding, ref, p_size, True)
random_search(f, a, dist, img_padding, ref, p_size)
return f

initialization()先是定义了一些变量,之后对所有像素调用cal_distance()

1
2
3
4
5
6
7
8
9
def initialization(A, B, p_size):
A_h = np.size(A, 0)
A_w = np.size(A, 1)
...
for i in range(A_h):
for j in range(A_w):
...
dist[i, j] = cal_distance(a, b, A_padding, B, p_size)
return f, dist, A_padding

propagation()主要调用了一次cal_distance()

1
2
3
4
5
6
7
8
9
10
11
12
def propagation(f, a, dist, A_padding, B, p_size, is_odd):
...
if is_odd:
if idx == 1:
...
dist[x, y] = cal_distance(a, f[x, y], A_padding, B, p_size)
if idx == 2:
...
dist[x, y] = cal_distance(a, f[x, y], A_padding, B, p_size)
else:
# 和 is_odd 时类似
...

random_search()则主要是在一个while循环里反复调用cal_distance()

1
2
3
4
5
6
def random_search(f, a, dist, A_padding, B, p_size, alpha=0.5):
...
while search_h > 1 and search_w > 1:
...
d = cal_distance(a, b, A_padding, B, p_size)
...

最后来看被调用最多的cal_distance()。这个函数用于计算图片A,B之间的某个距离。也别管这个距离是什么意思,总之是这一个有点耗时的计算。

1
2
3
4
5
6
7
8
9
def cal_distance(a, b, A_padding, B, p_size):
p = p_size // 2
patch_a = A_padding[a[0]:a[0] + p_size, a[1]:a[1] + p_size, :]
patch_b = B[b[0] - p:b[0] + p + 1, b[1] - p:b[1] + p + 1, :]
temp = patch_b - patch_a
num = np.sum(1 - np.int32(np.isnan(temp)))
dist = np.sum(np.square(np.nan_to_num(temp))) / num
return dist

至此,这份程序就差不多看完了。可以发现,代码大部分时候都在遍历像素,且遍历每个像素时多次调用cal_distance()函数。而我们知道,拿Python本身做计算是很慢的,尤其是在一个很长的循环里反复计算。这份代码性能较低,正是因为代码在遍历每个像素时做了大量计算。

我以前看过一篇文章,说Numba库能够加速Python科学计算程序,尤其是加速带有大量循环的程序。于是,我去学习了一下Numba的基础用法。

Numba 基础

Numba的官方文档提供了非常友好的入门教程。我们来大致把教程过一下。

Numba可以用pip一键安装。

1
pip install numba

Numba尤其擅长加速循环以及和NumPy相关的计算。使用@jit(nopython=True)(或@njit)装饰一个函数后,我们可以在这个函数里随便写循环,随便用NumPy计算,就像在用C语言一样。经Numba优化后,这个函数会跑得飞快。以下是官方给出的入门示例程序。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from numba import jit, njit
import numpy as np

x = np.arange(100).reshape(10, 10)


@jit(nopython=True) # 设置 "nopython" 模式以获取最优性能,等价于 @njit
def go_fast(a): # 初次调用时函数将被编译成机器码
trace = 0.0
for i in range(a.shape[0]): # Numba 喜欢循环
trace += np.tanh(a[i, i]) # Numba 喜欢 NumPy 函数
return a + trace # Numba 喜欢 NumPy 广播


print(go_fast(x))

Numba是怎么完成加速的呢?从装饰器名jit(JIT,Just-In-Time Compiler的简称)中,我们能猜出,Numba使用了即时编译技术,把函数直接翻译成了机器码,而没有像普通Python程序一样解释执行。Numba有两种编译模式,最常见的模式是令参数nopython=True,在编译中完全不用Python解释器。这种模式下,函数能以最优性能翻译成机器码。

修改上面的代码,我们可以测试该函数的速度。注意,由于采用了即时编译,函数在初次调用时会被编译。如果只要计算函数在编译后的运行时间,应该从第二次调用后开始计时。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from numba import jit
import numpy as np
import time

x = np.arange(100).reshape(10, 10)


@jit(nopython=True)
def go_fast(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace


# 不要汇报这个速度,因为编译时间也被算进去了
start = time.perf_counter()
go_fast(x)
end = time.perf_counter()
print("Elapsed (with compilation) = {}s".format((end - start)))

# 现在函数已经被编译了,用缓存好的函数重新计时
start = time.perf_counter()
go_fast(x)
end = time.perf_counter()
print("Elapsed (after compilation) = {}s".format((end - start)))

我们可以得到类似于下面的输出:

1
2
Elapsed (with compilation) = 1.0579542s
Elapsed (after compilation) = 1.7699999999898353e-05s

我们可以尝试一下不用Numba,直接用Python循环。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import numpy as np
import time

x = np.arange(100).reshape(10, 10)


def go_slowly(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace


start = time.perf_counter()
go_slowly(x)
end = time.perf_counter()
print("Elapsed (without Numba) = {}s".format((end - start)))

这个速度(4e-4)比用Numba慢了一个数量级。

1
Elapsed (without Numba) = 0.00046979999999985367s

也就是说,我们只要在普通的Python计算函数上加一个@jit(nopython=True)(或@njit),其他什么都不用做,就可以加速代码了。让我们来用它改进一下之前的PatchMatch程序。

用Numba计时编译加速PatchMatch

让我们开始做PatchMatch的性能调优。首先,根据性能优化的一般做法,我们要得知每一行函数调用的运行时间,找到性能瓶颈,从瓶颈处开始优化。我们可以用line_profiler来分析每一行代码的运行时间。用pip即可安装这个库。

1
pip install line_profiler

把主函数修改一下,在调用算法入口函数时拿LineProfiler封装一下,再用lp.add_function添加想监控的函数,即可开始性能分析。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
if __name__ == "__main__":
img = np.array(Image.open("./cup_a.jpg"))
ref = np.array(Image.open("./cup_b.jpg"))
p_size = 3
itr = 5

# start = time.time()
# f = NNS(img, ref, p_size, itr)
# end = time.time()
# print(end - start)

# reconstruction(f, img, ref)

lp = LineProfiler()
lp_wrapper = lp(NNS)
lp.add_function(propagation)
lp.add_function(random_search)
f = lp_wrapper(img, ref, p_size, itr)
lp.print_stats()

性能分析结果会显示每一行代码的运行时间及占用时间百分比。从结果中可以看出,在入口函数NNS()中,random_search()最为耗时。这是符合预期的,因为random_search()里还有一层while循环。

现在,我们应该着重优化random_search()的性能。我们继续查看一下random_search()的性能分析结果。

结果显示,绝大多数时间都消耗在了while循环里。也和我们之前分析得一样,cal_distance()是耗时最多的一行。除了random_search()外,其他几个函数也多次调用了cal_distance()。因此,我们目前代码优化的目标就定格在了cal_distance()身上。

刚刚学完了Numba,这不正好可以用上了吗?我们可以尝试直接给cal_distance()加一个@njit装饰器。

1
2
3
4
5
6
7
8
9
@njit
def cal_distance(a, b, A_padding, B, p_size):
p = p_size // 2
patch_a = A_padding[a[0]:a[0] + p_size, a[1]:a[1] + p_size, :]
patch_b = B[b[0] - p:b[0] + p + 1, b[1] - p:b[1] + p + 1, :]
temp = patch_b - patch_a
num = np.sum(1 - np.int32(np.isnan(temp)))
dist = np.sum(np.square(np.nan_to_num(temp))) / num
return dist

修改完代码后,再次运行程序。这次,程序报了一大堆错误。大致是说,在某一行碰到了Numba识别不了的函数。应该把np.int32()的强制类型转换改成.astype(np.int32)

1
2
num = np.sum(1 - np.int32(np.isnan(temp)))
^

改完之后,如果Numba版本较老,还会碰到新的报错:

1
Use of unsupported NumPy function 'numpy.nan_to_num' or unsupported use of the function.

报错显示,numpy.nan_to_num函数没有得到支持。再次翻阅Numba文档,可以发现,Numba并不支持所有NumPy函数。Numba对NumPy的支持情况可以在文档里查询(需要把文档切换到你当前Numba的版本)。

总之,cal_distance()这个函数不改不行了。得认真阅读一下这个函数的原理。原来,cal_distance(a, b, A_padding, B, p_size)函数是算图像A_padding和图像B中某一个像素块的均方误差的平均值,其中,像素块的边长为p_size,像素块在A_padding的坐标由a表示,在B中的坐标由b表示。

1
2
3
4
5
6
7
8
9
10
11
12
13
def cal_distance(a, b, A_padding, B, p_size):
p = p_size // 2
# 根据坐标a和边长p从A_padding里取像素块
patch_a = A_padding[a[0]:a[0] + p_size, a[1]:a[1] + p_size, :]
# 根据坐标b和边长p从A_padding里取像素块
patch_b = B[b[0] - p:b[0] + p + 1, b[1] - p:b[1] + p + 1, :]
# 求差
temp = patch_b - patch_a
# 根据非nan像素数量算有效像素数量
num = np.sum(1 - np.isnan(temp).astype(np.int32))
# 排除nan,求差的平方和,再除以有效像素数量
dist = np.sum(np.square(np.nan_to_num(temp))) / num
return dist

代码里还有一些奇怪的有关nan的运算:如果像素块里某处有nan,就说明此处像素无效,不应该参与均方误差的运算。为什么图像里会有nan呢?我们得阅读代码的其他部分。

nan是在初始化函数initialization()里加入的。A_padding原来是图像A在周围填了一圈nan的结果。我们大致能猜测出作者填充nan的原因:从A中取像素块时,若像素块在边缘,则有一些像素就不应该被计算了。拿条件语句判断这些无效像素比较麻烦,作者选择干脆在图像A周围填一圈nan,保证每次取像素块时不用判断无效像素。等算误差的时候再判断根据nan排除无效像素。

1
2
3
4
5
6
def initialization(A, B, p_size):
A_h = np.size(A, 0)
A_w = np.size(A, 1)
p = p_size // 2
A_padding = np.ones([A_h + p * 2, A_w + p * 2, 3]) * np.nan
A_padding[p:A_h + p, p:A_w + p, :] = A

使用nan填充,既耗时,兼容性又不好。为了尽可能加速cal_distance(),我把填充改成了edge填充,即让填充值等于边界值,并取消了无效像素的判断。也就是说,若像素块取到了图像外的像素,则认为这个像素和边界处的像素一样。这个假设是很合理的,这种修改几乎不会损耗算法的效果。

除此之外,为了进一步减少cal_distance()中的计算,我把要用到的变量都提前在外面算好再传进来。由于现在不需要考虑无效像素的数量,可以直接对误差求和,不用再算平均值,少做一次除法。还有,现在用@njit装饰了函数,可以放心大胆地在循环里做计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def initialization(A, B, p_size):
A_h = np.size(A, 0)
A_w = np.size(A, 1)
p = p_size // 2
A_padding = np.pad(A, ((p, p), (p, p), (0, 0)), mode='edge')

# Numba 循环写法
@njit
def cal_distance(x, y, x2, y2, A_padding, B, p):
sum = 0
for i in range(p + p + 1):
for j in range(p + p + 1):
for k in range(3):
a = float(A_padding[x + i, y + j, k])
bb = B[x2 - p + i, y2 - p + j, k]
sum += (a - bb)**2
return sum

当然,用NumPy实现cal_distance也是可以的。

1
2
3
4
5
# NumPy 等价写法,加上@njit更快
def cal_distance(x, y, x2, y2, A_padding, B, p):
patch_a = A_padding[x:x + p, y:y + p, :].astype(np.float32)
patch_b = B[x2 - p:x2 + p + 1, y2 - p:y2 + p + 1, :]
return np.sum((patch_a - patch_b)**2)

经测试,把nan的判断全部去掉后,使用NumPy版的cal_distance(),程序的运行时间降到了60秒。给NumPy版的cal_distance()加上@njit,运行时间进一步降低到了33秒。而如果使用带@njit装饰的循环写法,则运行时间也差不多是33秒,甚至还略快一些。这些测试结果印证了Numba的特性:

  1. Numba可以加速和NumPy张量相关的计算
  2. 在Numba中使用循环不会降低运行速度

成功用@njit优化完了代码中最深层的cal_distance(),我们会想,是不是所有函数都可以用同样方法加速?我们可以来做个实验,给最外层的入口函数NNS()加上@njit

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
@njit
def NNS(img, ref, p_size, itr):
A_h = np.size(img, 0)
A_w = np.size(img, 1)
f, dist, img_padding = initialization(img, ref, p_size)
for itr in range(1, itr + 1):
if itr % 2 == 0:
for i in range(A_h - 1, -1, -1):
for j in range(A_w - 1, -1, -1):
a = np.array([i, j])
propagation(f, a, dist, img_padding, ref, p_size, False)
random_search(f, a, dist, img_padding, ref, p_size)
else:
for i in range(A_h):
for j in range(A_w):
a = np.array([i, j])
propagation(f, a, dist, img_padding, ref, p_size, True)
random_search(f, a, dist, img_padding, ref, p_size)
return f

运行程序,会得到类似于下面的报错:

1
Untyped global name 'initialization': Cannot determine Numba type of <class 'function'>

把报错放网上一搜,原来,@njit的自定义函数只能调用加@njit的自定义函数。也就是说,在上面这份代码里,我们虽然用@njit装饰了NNS(),但我们自己定义的initialization(), propagation(),random_search()全部都没有用@njit装饰,因此NNS()的编译会出错。看来,我们得自底向上一步一步加上@njit了。

先来尝试修改一下initialization()。很可惜,直接加上@njit会报错。

1
2
3
4
5
6
7
8
9
10
11
12
13
@njit
def initialization(A, B, p_size):
A_h = np.size(A, 0)
A_w = np.size(A, 1)
B_h = np.size(B, 0)
B_w = np.size(B, 1)
p = p_size // 2
random_B_r = np.random.randint(p, B_h - p, [A_h, A_w])
random_B_c = np.random.randint(p, B_w - p, [A_h, A_w])
...

#报错
# Use of unsupported NumPy function 'numpy.size' or unsupported use of the function.

报错是说有不支持的NumPy函数numpy.size。实际上,不仅是numpy.size,Numba也不支持有三个参数的np.random.randint。为了解决此问题,和刚刚对numpy.nan_to_num的处理一样,最好是能用其他等价写法来代替不支持的函数。如果不行的话,则应该把不支持的运算和支持的运算分离开,只加速支持的那一部分。对于initialization(),我采用了第二种解决方法,把函数中耗时的循环拆开来单独用@njit装饰,其余有不支持的NumPy函数的部分就不用Numba优化了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@njit
def initialization_loop(A_padding, B, f, dist, A_h, A_w, random_B_r,
random_B_c, p):
for i in range(A_h):
for j in range(A_w):
x, y = random_B_r[i, j], random_B_c[i, j]
f[i, j, 0] = x
f[i, j, 1] = y
dist[i, j] = cal_distance(i, j, x, y, A_padding, B, p)


def initialization(A, B, A_h, A_w, B_h, B_w, p_size):
p = p_size // 2
random_B_r = np.random.randint(p, B_h - p, [A_h, A_w])
random_B_c = np.random.randint(p, B_w - p, [A_h, A_w])
A_padding = np.pad(A, ((p, p), (p, p), (0, 0)), mode='edge')
f = np.zeros([A_h, A_w, 2], dtype=np.int32)
dist = np.zeros([A_h, A_w])
initialization_loop(A_padding, B, f, dist, A_h, A_w, random_B_r,
random_B_c, p)
return f, dist, A_padding

另外的两个函数propagation()random_search()只会碰到取形状函数numpy.size的问题。这个问题很好解决,只要把numpy.size挪到函数调用外即可。

initialization_loop()propagation()random_search()都加上@njit后,程序的运行时间从33秒猛地降到了3秒左右。可以说,只用加@njit的方法的话,程序已经没有优化空间了。

用Numba提前编译加速PatchMatch

又看了看PatchMatch的源码,我发现,PatchMatch算法会先为每个像素随机生成一个匹配关系。然后,算法会迭代更新匹配关系。迭代得越久,匹配关系越准。而我之后要用PatchMatch处理一段视频,算所有帧对第1帧的匹配关系。那么,对于视频这种连续的图像序列,我能不能让第3帧初始化匹配关系时复用第2帧的匹配结果,第4帧复用第3帧的匹配关系,以此类推,以减少迭代次数呢?

说干就干。我准备先测试一下减少迭代次数后代码运行时间能缩短多少。迭代次数itr是在main函数里指定的,作者默认的数值是5。我把它改成1测试了一下。

1
2
3
4
5
6
7
8
9
10
11
if __name__ == "__main__":
img = np.array(Image.open("./cup_a.jpg"))
ref = np.array(Image.open("./cup_b.jpg"))
p = 3
itr = 1

start = time.time()
f = NNS(img, ref, p, itr)
end = time.time()
print(end - start)
reconstruction(f, img, ref)

结果,原来要花3秒的程序还是要花接近3秒,时间缩短得非常不明显。这不应该啊,理论上程序的运行时间应该大致和itr成正比啊。

测试了半天,我突然想起Numba文档里讲过,@njit是即时编译,函数的编译会在初次调用函数时完成。我每次运行程序时,大部分时间都花在了编译上,因此整个程序的运行时间几乎不由迭代次数决定。

我之后要反复运行PatchMatch程序,而不是通过运行一次程序来处理大批数据。即时编译的代价我是接受不了的。于是,我去文档里找到了Numba提前编译(AOT,ahead of time)的使用方法。

Numba AOT可以把Python函数编译进一个模块文件中。想在其他地方调用被编译的函数时,只需要import 模块名即可 。

官方给出的Numba AOT示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from numba.pycc import CC

cc = CC('my_module')

@cc.export('multf', 'f8(f8, f8)')
@cc.export('multi', 'i4(i4, i4)')
def mult(a, b):
return a * b

@cc.export('square', 'f8(f8)')
def square(a):
return a ** 2

if __name__ == "__main__":
cc.compile()

首先,程序要用一个模块名实例化一个CC。该模块名是未来我们import时用到的名称。之后,对于想编译的函数,我们要用@cc.export装饰它。@cc.export的第一个参数是调用时的函数名(原来的函数名会被舍弃),第二个参数用于指定函数返回值和参数的类型。做完所有这些准备后,使用cc.compile()即可完成编译。

运行该程序,会得到一个模块文件。根据平台的不同,该模块文件名可能是my_module.somy_module.pydmy_module.cpython-34m.so。不管文件名是什么,只要是在同一个文件夹下,我们就可以用下面的Python命令调用这个模块文件。

1
2
3
4
5
>>> import my_module
>>> my_module.multi(3, 4)
12
>>> my_module.square(1.414)
1.9993959999999997

用Numba做即时编译时,函数的返回值类型和参数类型可填可不填。而Numba提前编译中,必须要填入函数的返回值类型和参数类型。这让编写Numba提前编译的工作量大了不少,已经不像是在写Python,而是在写C了。

还有一点值得注意。和使用即时编译时一样,自定义的函数在调用其他自定义函数时,必须要加上@njit。所以,会出现一个函数即有@njit,又有@cc.export的情况。

学习使用Numba提前编译时,最主要是要学习Numba是怎么用字符串代表参数类型的。比如,i4是32位整型,u1是8位无符号整型,u1[:, :, :]是三维8位无符号整型,void是无返回值。这些表示可以在官方文档里找到。

以我写的Numba AOT PatchMatch的编译代码为例,我们可以看一看参数类型和返回值类型是怎么描述的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
from numba import njit
from numba.pycc import CC

cc = CC('patch_match_module')


@njit
@cc.export('cal_distance', 'f4(i4, i4, i4, i4, u1[:, :, :], u1[:, :, :], i4)')
def cal_distance(x, y, x2, y2, A_padding, B, p):
...


@njit
@cc.export(
'initialization_loop',
'void(u1[:, :, :], u1[:, :, :], i4[:, :, :], f4[:, :], i4, i4, i4[:, :], i4[:, :], i4)'
)
def initialization_loop(A_padding, B, f, dist, A_h, A_w, random_B_r,
random_B_c, p):
...


@njit
@cc.export(
'propagation',
'void(i4[:, :, :], i4, i4, i4, i4, f4[:, :], u1[:, :, :], u1[:, :, :], i4, b1)'
)
def propagation(f, x, y, A_h, A_w, dist, A_padding, B, p_size, is_odd):
...


@njit
@cc.export(
'random_search',
'void(i4[:, :, :], i4, i4, i4, i4, f4[:, :], u1[:, :, :], u1[:, :, :], i4, f4)'
)
def random_search(f, x, y, B_h, B_w, dist, A_padding, B, p_size, alpha=0.5):
...


if __name__ == "__main__":
cc.compile()

运行该程序后,在我的电脑上得到了名为patch_match_module.cp37-win_amd64.pyd的模块文件。可以在其他代码里通过import patch_match_module调用编译好的函数了,比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import patch_match_module

def NNS(img, ref, p_size, itr):
A_h = np.size(img, 0)
A_w = np.size(img, 1)
B_h = np.size(ref, 0)
B_w = np.size(ref, 1)
f, dist, img_padding = initialization(img, ref, A_h, A_w, B_h, B_w, p_size)
for itr in range(1, itr + 1):
if itr % 2 == 0:
for i in range(A_h - 1, -1, -1):
for j in range(A_w - 1, -1, -1):
patch_match_module.propagation(f, i, j, A_h, A_w, dist,
img_padding, ref, p_size,
False)
patch_match_module.random_search(f, i, j, B_h, B_w, dist,
img_padding, ref, p_size,
0.5)
else:
for i in range(A_h):
for j in range(A_w):
patch_match_module.propagation(f, i, j, A_h, A_w, dist,
img_padding, ref, p_size,
True)
patch_match_module.random_search(f, i, j, B_h, B_w, dist,
img_padding, ref, p_size,
0.5)
return f

加上最后这步提前编译后,PatchMatch的运行时间从3秒降低到了0.6秒多。程序从最开始的180秒降到了0.6秒,几乎快了300倍。而且,如果是处理视频,还可以通过复用前一帧信息来减少迭代次数,进一步缩短每一帧的平均处理时间。能加速这么多,并不是我太强,而是Python实在太慢了。纯Python就不应该用来写科学计算程序。

总结

通过阅读这篇文章,相信大家能根据我这次Python PatchMatch性能优化经历,在不阅读Numba文档的前提下自然而然地学会Numba的用法。我把文章中提到的和Numba性能优化有关的知识点按使用顺序总结一下。

  1. 在面向应用的程序中,不要用Python写科学计算程序。哪怕要写,也要尽可能避免在循环中使用大量计算,而是去调用各个库的向量化计算。
  2. 直接在想优化的函数前加@njit装饰。在待优化函数里使用循环、NumPy函数都是很欢迎的。
  3. 如果碰到了Numba不支持的函数,可以通过两种方式解决:1)用等价的Numba支持的函数代替;2)把不支持和支持的部分分离,只加速支持的部分。
  4. 一个带@njit函数在调用另一个自定义的函数时,那个函数也得加上@njit。因此,应该自底向上地实现Numba即时编译函数。
  5. 如果你接受不了计时编译的编译时间,可以使用提前编译技术。使用提前编译时,主要的工作是给参数和返回值标上正确的类型。

Numba确实很容易上手,只要会加@njit,剩下碰到了什么问题去搜索一下就行。Numba的官方文档很详细,想深入学习的话直接看文档就行了。

本项目的代码仓库为:https://github.com/SingleZombie/Fast-Python-PatchMatch 。在原作者仓库的基础上,我添加了PatchMatch_numba_jit.pyPatchMatch_numba_compile.pyPatchMatch_numba_aot.py这三个文件。它们分别表示即时编译运行程序、提前编译编译程序、提前编译运行程序。

经历

学校申请在2023年1月31日截止,托福成绩要求100分。

我在两年前读本科时考过托福,阅读/听力/口语/写作的分数分别为26/25/21/22,共94分。在这两年里,我的英语听说读写水平均有提升。大致估计一下,我应该可以轻轻松松地多考6分,拿到100+的成绩。

带着这样的想法,我2022年11月才开始准备托福。准备之前,我先订下了复习的策略。由于我不知道现在的考试水平,只能以两年前的成绩为基础,思考各部分的提分策略。阅读、听力都是客观题,很容易自学,各提个2分非常简单。据很多人反映,口语很难靠自学提升,那我只要保持之前的21分即可。这样,要让总分超过100,这次考试最重要的就是写作了。对于看重英语读写教育的中国学生来说,英语写作的基础在高中时就学得差不多了。而针对托福考试的写作技巧,则可以快速学会。因此,我应该花时间集中攻关托福写作技巧。综上,我的复习策略总结为:稳步提高阅读与听力,保持口语,攻关写作。

11月,我一边正常工作,一边断断续续地做了阅读和听力的TPO真题,再次熟悉了这两部分的应试方法。同时,我也做了一两套口语题,熟悉了题型,保证听完题目后能张口答题。最后一周,我开始集中学习写作技巧。网上的托福「教程」全是打广告的垃圾信息。我从互联网垃圾堆里翻出了一篇论坛上的经验分享贴,下载了一本叫做《十天突破新托福写作》的书。这本书虽然成书于2010年,但其内容十分全面,毫不过时。我认真学习了一些能够短期内掌握的写作技巧。

慢吞吞地准备了一个月,差不多了。我把第一场托福家考的时间订在12月5日早上9:50。我不方便把考试时间订在我最清醒的晚上,又考虑到我有以前线下考试的经验,才把时间安排在了早上。另外,12月5日是星期一,在此之前我能有一整个周末的时间来做最后的准备。如果一切顺利的话,我可以在不影响正常工作的前提下于周一下午结束托福考试。

第一场考试即将开始。我知道自己准备得不够好,却也因此做好了多考几次的打算,只把这次考试当成了一次分数测试,没有什么心理压力。托福家考的考官会对用作考场的个人房间进行检查。和考官的几句英文对话,也恰到好处地缓解了我的紧张,令我的大脑迅速进入了英语环境。克服了刚开始动脑时打的几个哈欠后,我顺顺利利地完成了考试。

和普通的线下托福机考一样,家考结束后,只有客观题的阅读和听力会立刻出分。我迫不及待地点下查分按钮后,看到的是一个令我傻眼的分数:阅读27,听力19。

我曾经为查分结果做出多种假设,以决定下一次考试的时间。如果阅读和听力的分数特别好,我就等所有分数出了再看要不要继续考试;如果阅读和听力的分数较差,我就再学一阵,学得差不多了就考。无论如何,看到分数后,我的心情要么是欣喜,要么是失落。可是,看到这只有19分的听力,我先是撇嘴笑了笑,随后皱起了眉头。一阵不服气的火焰从我心中燃起。种种迹象表明,我的英语听力水平没有那么差。在这次听力考试时,我每篇材料也听得津津有味。怎么能考出这样的分数出来呢?这已经不是我有问题了,这绝对是因为托福出题方ETS出题不力,不能让题目反映应试者水平。空口无凭,下一次,我会用我的无懈可击的表现来证明托福听力的不合理。

我也不管这次的分数怎么样了,准备在最快的时间里开始下一次考试。根据最新的规定,3天之后就能考下一场托福。但是,在临考一周内报名要交额外的费用。于是,我决定在不缴纳晚考费的第一天,也就是下周一的同一时间,再战托福。在这一周,我会把我的托福听力水平提升到最强的境界。考完了,如果听力考出高分,我还要写一篇文章来总结托福听力经验,把托福听力狠狠地踩在地上。接下来一周的剧本,已经准备就绪。

自我上进并不是令人奋斗的理由。只有对自我或外界的不满而产生的复仇心理,才能催促人不断前进。我的拖延被立刻治好。第一场考试结束当晚,我就精神饱满地重新开始了托福听力的准备。本来我想把这整周都用在听力的练习上。可是,由于学习速度过快,练了三天,我就发现我的听力水平大有提升,以至于没有什么进步空间了。于是,最后几天,我还是去找了找口语和写作的状态,为第二次考试准备。

12月12日,周一,早上9:50,为了完成对托福听力的复仇,我又杀回来了。由于是第二次参加托福家考,一切安检的流程我都轻车熟路。十分钟内,考官就检查完了资料和考试环境,第二场考试开始了。

第一部分是阅读。我平时做阅读总是会错几题,尤其是最后的六选三大题。但是,我对此毫不在意。因为我知道,平时我阅读出错,完全是因为我不认真,太浪。如果换成真正的考试,我就会拿出全部的实力,认真记忆文章,根本不会错那么多题。事实也确实如此。在第二次阅读考试中,我已经适应了考场上做阅读的策略,做题做得心应手。72分钟,4篇文章,我只需要每篇文章都花18分钟即可。每篇文章我都留3分钟来解决最后的六选三大题。把所有题做对不成问题。

之后是听力。我的状态和上周差不多,略有不同的是,我记笔记的方法更好了,信心更足了。口语和综合写作的发挥也和上周差不多。最后一项任务是独立写作,其话题稍微有点难度:「政府要建拆旧建筑修新建筑,一批人要被迫搬到新的地方,你同意吗?」我半天没有找到特别通畅的写作思路,只好硬着头皮分个人、企业、政府三个方面扯了一堆反对的理由。

考试结束,又到了开奖时间。确认考试结束的画面我看都不看,狂点着下一步,迫不及待地查看起阅读和听力的分数:

阅读28,听力27。

多么令人惊喜的分数!这两部分的分数均创我个人新高,成功达成了「阅读、听力各高2分」的目标。另外,我前一周吹的牛也成真了。我确实考出了一个不错的听力成绩,有资格去写托福听力准备攻略了。再次确认分数没有看错后,我在房间里大笑不止。计划的顺利执行、复仇的完美实践、成功的如期而至、心腹大患已除的如释重负。种种喜讯,充斥着我的大脑,让我只能以大笑来释放那高悬已久的紧张心情。我一边笑,一边狂妄地叫道:「如果这个阅读和听力的分数都不能让我总分上100,我就去炸了ETS!」我觉得这次总分肯定有100分了。把两年前的口语阅读分搬过来,就已经有98分了。这两部分只要稍微提高一点,100分就有了。

两天后的早上,也就是第一场考试结束后的第九天,ETS寄来邮件,说第一场考试的成绩已出。我一看,阅读27,听力19,口语23,写作26,总分95。两年前,我的口语和写作分别是21分和22分。这样一看,我口语和写作的提升确实不小。把昨天刚考出来的阅读和听力在拼起来,总分就有104分了。相比上周的发挥,我有足足4分的退步空间。我觉得我这次考试上100分已经是板上钉钉的事情了。

我兴冲冲地和领导讲,托福已经考完了,我从这周开始可以正常干活了。用于做笔记的白板被我扔到房间一角,为了通过考试安检而清理的房间又被我弄乱,羞于让考官看到的二次元电脑桌面又被我换回。生活重回正轨,仿佛什么都没有发生过。我写了一篇托福听力准备攻略,以此为送别托福的赠礼。

放松了一周,又到了周一早上。总算,我不用再面对9:50开始的托福考试,可以睡个懒觉,再去学校。本来是轻松快乐的一天,不知怎地,我怎么都提不起劲,仿佛是背后被锁链扯住了。到晚上临睡时,不安感越来越强。我这才想起来,明天是考试后第八天。最快的情况下,明天就会出分了。在最终分数出来之前,一切还没有结束。

周二早上9点,我在睡醒后立刻点开了邮箱。果然,上次的托福分数已经出了。我连忙点进查分页面,第一眼看到了一个三位数的分数。我按捺住激动的心情,仔细一看,这个分数是104分。原来这是TOEFL MyBest Score,也就是各个小分的最优拼分结果。和我算得一样,拼分结果是104分。现在庆祝还为时尚早,我把眼睛往上挪了挪,看到了我最不想见到的结果——一个只有两位数的总分。定睛一看,上次考试的总分是99分!口语、写作都只考了22分,这导致最终的总分刚好离及格线100分差1分。

反复确认学校不接受拼分结果后,我的心凉了下来。无数的负面想法萦绕在我的脑中:「考前怎么不多准备一下口语和写作」、「为什么这么晚才考试」、「考试的时候状态好一点就好了」……。但是,现在,连多余的自责都是没有意义的。我要做的,只有尽力考好下一次托福。我也没有思考,照着上次的习惯,报考了下周二早上9:50的托福家考。一切从头开始。

这一周,我完全失去了活力。每天上午,楼下都在轰轰隆隆地装修。我晚上本来就睡得晚,白天多躺一下的权利也失去了。生活中又突然碰到了一些事,我不得不花了一两天去处理。好不容易到了周末,我才有心思去准备一下托福。可是,我一看到屏幕上的做题界面,就感到全身有无数根线在撕扯着我,让我快点结束做题。

截至周一,考试前一天,我这周几乎没有花时间准备考试,生物钟也乱得一塌糊涂。我想,最后一天了,准备考试也来不及了,就把生物钟调整一下吧。我找事情消磨着时间,希望自己不要在白天补觉了。下午,吃过午饭,我实在困得受不了,就靠在床头,心想:「就靠着休息一下吧。靠着睡睡得不舒服,睡一下应该就能醒来。」但当我醒来的时候,天已经黑了。我已经困到连靠着睡也能睡四个小时了。根据前几天的经验,我知道今天晚上睡不着了。

睡不了觉,那就用最后的时间准备一下考试吧。按我以前的复习习惯,考试前最后准备的应该是口语。然而,听完一道口语的题面,花完15秒的准备时间,当录音开始的一瞬间,我发现自己张嘴发不出声音。原来,在准备时间里,我的大脑一片空白,根本组织不出英语句子。既然如此,开口讲一些乱七八糟的话,也是浪费时间。巨大的焦虑,把开口讲英语也变成了一件困难的事。

什么事都做不下去,我只好关灯,上床,躺着,一秒一秒迎接末日的来临。在最应该睡觉的夜晚,我的大脑却无比清醒。我糊里糊涂地想了很多事情。最后也想开了,这次考试肯定是考不过了。今天是12月27日,离1月31日还早,还有机会准备。我爬起来想去取消这次考试,挽回一些考试费,却发现考试早就定下来不允许取消了。没办法,就当这次考试是调整心情吧。相比白白花了钱不去考试,还是参加一下比较好。

早上七点,天亮了。我总算困了,睡了下去。曾经慢慢吞吞令我倍感煎熬的时钟,此刻却又卯足了劲飞奔起来。没过多久,闹钟就响了,提醒我9:50时有一件非做不可的事。9:30,我艰难地起了床,准备迎接第三次考试。

我从来没有喝醉过酒,此刻,我却能深刻理解喝醉酒的痛苦。听说酒喝多了会呕吐,那么,止不住的呕吐,应该和止不住的哈欠是一样的吧。明知困得睁不开眼,却要在不断的哈欠中强打精神。这种折磨不是和明知胃里没有东西却不得不重复地呕吐一样吗?每次早上考托福时,我其实睡得都不太够。做一开始的阅读题时,精神又必须得完全集中。所以,每次考托福,我都会打一阵子哈欠,花一些时间让大脑适应。而今天,我的睡眠时间实在太短了,以至于哈欠不断,大脑久久不能集中。还是4篇阅读,72分钟,平均每篇18分钟。第一篇阅读我恰好花18分钟做完。做第二篇时,我被一道题卡住了,多花了足足6分钟。后面两篇文章我只好平均花15分钟飞快地做完。做完阅读,我就知道这次考试已经没救了。随后,我又勉强做完了听力题。到了口语考试时,我糟糕的状态就一发不可收拾了。和前一天晚上的表现一样,我在口语的准备时间里非常慌张,根本动不了脑子。等开始录音了,我才开始边想边讲,凭借本能让英语单词一个一个从嘴里蹦出来。最后,到了写作的时候,我的脑子总算清醒了。结果,独立写作来了道题面很长,题材很怪的题。我根本想不出这个话题可以从哪三个角度来讲,只好硬凑字数,第一次写了一篇只有两个主体段的文章。

考试结束,我漠然地查看客观题成绩:阅读25,听力22。「还好,至少听力比第一次的19分高。」反正这次考试肯定考不过100,我只能苦中作乐,怎么乐观怎么想了。确认考试结束后,我如释重负,往床上猛地一躺。

差不多是上周的这个时候,我确认了托福还要再考。当时,我曾用一张薄纸封住了内心的负面想法,只希望下次考试能够尽力。虽然我无时无刻不在被透过纸的针扎痛,但我一直扛到了现在。如今,已经没有任何克制的需要了。「为什么不早点准备?」、「为什么不好好准备?」、「为什么会差1分?」、「为什么这周状态那么差」、「为什么……?」、「为什么……?」、「为什么……?」、……。新账加上旧账,弯刀接着利剑。无尽的责问,让疲惫的我想躺在床上就此昏睡过去。可是,一想到没能好好睡觉也是失利的败因之一,我又立刻咬牙坐了起来。疲倦与饥饿接踵而至,我决定先去吃饭。

吃完饭,我去超市里买生活用品。没想到,几天不来,超市从白与绿的圣诞风格,变成了以红为主的春节风格。我这才发现,圣诞节已经在不经意间过去了。再这样考不过,新年也是同样的过法吧。在这一个多月,我除了准备托福和消磨时间,别的什么都没干。学习停了,工作停了,什么都不会做了。我就在不断重复着考托福,考不过,考托福,考不过。这样下去,一月底还考不过,导师可能就不要我了。后面的事我也不敢再想下去了。

吃饱喝足后,我冷静分析了一下现状。我认为我的水平没有问题,只是这周的状态太差了。第二次考试相比第一次,写作的分数低了很多。我还是应该保持其他成绩,全力提升写作。离申请截止还有四周,最多最多还有三次考试的机会。我振作起来,又报名了一周后,也就是下个周二的考试。说来奇怪,一考完,我的精神状态就正常了,当天晚上也睡得很香。过了两天,楼下的装修声也听不到了,我的生物钟算是调整回去了。我名义上还是在正常上班,工作日的时候敲了敲代码,生活的自信又找回来了。

正常准备了一周,转眼又到了考试前一天。按惯例,最后的时间我还是在练口语。上周的事情让我心有余悸,这次一定不能怠慢口语了。可是,无论我怎么练,说起话来都磕磕巴巴的,找不回第一次考试时那种飞速组织语言、毫无顾虑地发言的状态了。我花了很多时间,才练到勉勉强强能把回答说完。

1月3日,第四次考试。我虽然睡得还不是很充足,但算是回到了那种打几个哈欠就能清醒的状态了。然而,这次又出现了一点意外情况。之前的考试,我怕考官会禁止我使用纸巾,鼻子不舒服也没有拿纸来擤鼻涕。这次考试,我的鼻子特别不舒服。我拿纸快速地擤了一下鼻涕,发现考官并没有提出异议。我的鼻子好像接收到了这个信号,开始止不住地流鼻涕。结果,我隔一两分钟就要擦一次鼻涕。本来就困,呼吸还不通畅,我阅读做到一半突然就开始手脚发麻。还好我脑子一直是清醒的,除了中间略有卡壳外,大部分阅读题我都能确定地答出。做到听力时,我的鼻子总算正常了,考试开始平稳进行。口语考试时,我的状态也和考前一样,非常一般,只是恰好能够把题目答完。由于之前做过足够的练习,我做综合写作倒是十分顺利。可是,做独立写作时,又不太对劲了。题面是「乡村生活和城市生活哪个更好」。我想出了三个城市好的理由,却没太想好怎么论证,写了几个不太自然的示例。

考完,又是要提前看分。经历丰富的我面对查分已经没有了任何激动,只是平淡地一步一步点进了查分界面。这次,阅读28,听力21。阅读的发挥是符合预期的,但这个听力是怎么回事?怎么多睡了几个小时还没有只睡两小时做的听力分数高?

经历了太多太多,我已经麻木了。看到这种分数,我既没有开心,也没有难过。我只想问,接下来该怎么办?越花时间准备,听力考的分数越低。亏我还是写过托福听力攻略的人。反正花时间在准备考试上也没有提升,不如多玩一玩,那所以说我的懈怠也是合理的。托福考试嘛,其实就是看运气。包括两年前的那次在内,5次考试,我听力只有两次上了25分。我的听力水平其实很差。碰到运气好了,刚好听清了,分数就很高。多数(>50%)情况下,分数是很低的,能反映我真实水平的。写作也是,运气好,考官心情好,给了个26分。其实我水平没有那么高。我还能做什么呢?只能多考几次,撞撞运气了。

越想我就越不服气。为什么我总是习惯于责怪自己呢?就不能把失败归咎于外界吗?我不是曾经说过,托福听力有问题,不能反映考生的真实听力水平吗?今天的身体状态难道没有影响我考试吗?我没有问题,就是客观情况有问题。

想着想着,第一种声音又冒出来了:那为什么在半数以上的考试中,我的听力成绩那么差?这不就是菜吗?考试成绩就是一切。考得差了就是自己的问题。

自我斗争了半天,我决定用一种公平的方式来审判:在下次考试中,我将尽可能排除一切客观上的不利,用最佳状态来考试。如果还是考不好,那就是我英语水平有问题;如果考得好,过去的事情就一笔勾销,我可以光明正大地把考试失利甩锅给考试状态。这时,我才发现我之前有多么愚蠢。考试准备得用不用功另说,为什么我不把应试状态调整到最好呢?睡眠不好调整,但是考试时间是可以调整的啊。我可以早上晚一小时再考,然后喝一瓶红牛拉满状态。提高应试状态比努力准备考试要简单得多,却又有效得多啊!

我再次分析起了现在的学习情况。这两周我都没练过阅读,全是在考场上边考边练。在真枪实弹的练习中,我的阅读技术已经在不经意间练得炉火纯青。这次考试我在状态极差的情况下都考了28分就是证明。而下次考试我会状态拉满,岂不是随随便便就能考个满分?这样一看,阅读就不用练了。写作练了这么久,综合写作也没有提升空间了,独立写作练得也没效果了,维持在22分以上应该没问题。剩下要练的只有听力和口语。听力我曾经拿过27分,我只要练回当时的状态就行了。至于口语,从第三次考试开始,我考试时的心态就出现了问题。所以,这次口语准备的重点是做到不要紧张。只要口语考出正常水平,不比22分低,总分也够100了。一周的准备时间是足够的。我毅然订下了下周二早上10:50的托福家考。

这一周的头几天,我把所有准备时间都放在了听力上。和以前的练习方法略有不同,现在,我只在每天最清醒的时候,用最集中的精神去做题。一旦觉得困了,注意力不集中了,我就去躺一下,绝不让状态差成为借口,也免得错太多题影响心态。调整了状态后,我的练习正确率果然高了不少。

周五,听力练得差不多后,我开始练口语。正好,第三次考试的成绩也出了。由于元旦放假,成绩出得晚了几天。在我状态最差的这次考试中,我的阅读/听力/口语/写作成绩分别是25/22/17/22,总分86,创下了最低分的纪录。这只有17分的口语分数深深扎在我的心里,导致我在练口语时,越说越紧张,越练越不利索。我只能勉强安慰自己,哪怕状态再差,写作都能考22分。如果阅读和听力能接近满分,口语22分左右就可以了。现在只要把口语练到以前的水平就行,不要有太大压力,不用练得太好。就这样,最后三天我一直在练口语。

1月10日,周二,早上10点,我懒洋洋地从床上爬起来。吃完早饭,灌下大半瓶红牛,我精神饱满地坐到了电脑前。这是第五次考试了,是拼上个人名誉的一战,绝对不容有失。熟练地通过安检后,我又一次进入了阅读题的测试。一看到满屏的文字,我下意识地打了个哈欠。我立刻为这个哈欠担心了起来:「怎么回事?我还是没睡醒吗?红牛没有用吗?」不行,我不能在这里停下!我眼睛一瞪,聚精会神地读起文章来。还好,能量饮料是有用的。自此之后,我就再也没打过哈欠,状态神勇地横扫了前两篇阅读。做到第三篇阅读时,我再次被一道题卡住了。好在这次留的时间非常充足,哪怕是在第三篇文章浪费了点时间,我还是给第四篇文章留了完整的18分钟。最后,我完美利用时间,十分确定地答完了每一道题。

之后是听力。我主观感觉自己的发挥和之前几次都一样,听力材料听得很清楚,完全看题目里的信息我是不是恰好记住了。不过,客观上来看,我的发挥确实好了不少。有一道很恶心的题是问「下面这些行为是左脑还是右脑负责的」,等于一道题问了好几个小问题。由于我的笔记非常全面,轻松地把这题回答出来了。

中途休息后,继续进行口语考试。口语是我最担心的一部分。考试还没开始,我的心脏就砰砰跳个不停。第一道独立口语题一出,我就傻眼了:「有些人喜欢早上学习,晚上工作,有些人则相反。你更倾向于哪一种呢?」哪有这种题啊?我爱早上工作早上工作,爱晚上学习晚上学习,哪有什么理由啊?15秒的准备时间里,我愣是一个字的回答也没想出来。开始录音,我只好想到什么说什么,连珠炮一般地把想到的内容全说出来。说完一看,还有一半的时间,内容太少,太空洞了。我只好勉强补了一两句。这道题做得极其糟糕,我狼狈地进入了下一题。很巧,第二道题目一念,我惊讶地发现这道题是第一次还是第二次考试的原题。这整套口语题我都做过!我又是惊喜又是担忧。惊喜的是后面的材料我都记得,怎么回答我都有数了;担忧的是我的状态十分糟糕,之前答起来毫无困难的题,这次就答得这么费劲。但不管怎么样,作为一个讲武德的人,我还是当成第一次做,好好地按流程记笔记,按现场记下的笔记作答,没有在念题目的时候就开始凭记忆把回答准备好。后面这几道口语题答得都还正常。

又到了最后的写作部分。不像之前几部分的考试,写作的时间非常充裕。不管之前再怎么紧张,在写作时多数人都能冷静下来答题。我这次就很不巧,综合写作刚开始没多久,考试程序崩溃了。电脑跳转到桌面上,我突然发现我的工作应用忘了关,右下角有个图标在一闪一闪。我连忙点开应用,回了领导一句话。正当我准备关掉应用的时候,鼠标的控制权突然被考官抢了过去,重新点到了考试程序上。我瞬间惊恐起来,担心考官会认为我在作弊。不过考官也没说什么,考试很快又继续了。我就这样战战兢兢地做完了综合写作。独立写作,题目是「有人觉得大学应该收费,有些人觉得大学的学费应该由政府承担。你支持哪种观点?」这别说拿英文了,拿中文要我半小时内写一篇看上去合理的文章都很难。看到这个题目,第一想法是如果学费太高,穷人就上不起学了。可是,这只能支撑一个观点,字数一定写不够。我只好选择支持大学收费,让大学收费和高质量的教育关联起来,再分别谈高质量的教育对个人、学校、政府的好处,勉强凑够了分论点。

考试结束,第五次客观题开奖了。我感觉这次阅读做得很好,但听力还是有几道不确定的题。失败这么多次了,也习惯了,考得差就差了吧。于是,我不抱希望地点进了查分页面。接下来,我看到了两个惊人的分数:阅读30,听力28。喜悦感瞬间传遍了我的全身。要是几周之前,我肯定就立刻开心地跳起来了。可是,有了上次半场开香槟的经历,伴随着喜悦感的,是一种更沉重的恐惧感。无数的负面想法扑面而来:我的口语可是考过17分的,要是这次再考个17分怎么办?考官要是把我中途点开工作应用的行为当成作弊,取消考试成绩怎么办?就这样,看着刚考出来的高分,我压根不敢庆祝,只是呆呆地坐着。十分钟后,我突然感到全身乏力。我这才想起,现在已经是下午3点了。我已经非常饿,非常疲惫了,完全是靠能量饮料才撑到了现在。

接下来的几天,我没有再准备考试了。做其他事也提不起劲,于是我一有空就会来算分。第四次考试的分数是在周三,也就是考试后第八天出的。阅读/听力/口语/写作的分数分别是28/21/20/23,总分92。如果口语和写作的分数还是和上周一样,那总分就有101分,够了。可是,上周我口语还只是考了20分。这说明我的口语还是没找回状态。万一口语再考一次17分,分数不就不够了吗?就是这样,我用着各种方式去估计第五次考试的分数。从理性上思考,我上100分是很稳的。但是,有了第二次99分的经历,我再也不敢有乐观的想法了。

一边算着分,我还一边估计着下次出分的时间。第一次出分用了九天,第二次和第四次用了八天,第三次因为假期的原因用了十天。可见,八天或九天出分是最可能的。另外,每次通知出分的邮件都是8:05左右发的,意味着早上8点大概就可以看到分数。根据这些分析,我准备从下周三早上开始,每天早起看看有没有出分。

在周三之前,我的日常生活仿佛冻结了起来。我就像沙漏里的沙子,我的生活意义就只是单纯地见证时间的流逝。周三清晨,我一直睡得不踏实。到了早上7点,我在没有闹钟的情况下突然清醒了过来,死盯着邮箱的界面,等着8点到来。7点1分,……,8点1分,8点2分,……。等到了9点,分数还是没有出。我知道今天是不会出分了,如断了弦一般躺下并睡了过去。

又在紧张中度过了一整天后,周四清晨,我以同样的方式在7点的时候醒了过来。如我所料,8点,出分了。这次,我没有看错,精准地找到了本场考试的总分——103分。阅读/听力/口语/写作分别是30/28/22/23。如果阅读和听力的发挥稍差,和第二次考试一样,总分也有100分;如果口语和上次一样只有20,总分也有101分。可以说,我的托福是实打实地有100分的水平。哪怕有几门发挥较差,也不影响总分超过100分。反复确认了考试分数后,我总算是松了一口气。在我的「个人名誉审判」上,我总算证明了自己的实力,总算有资格把以前的失败归咎于状态。考完了,胜利了,说什么都是对的。

周六是1月21日,除夕。我有幸能快快乐乐地过一个春节。1月31日之前,我考了五次托福,总算考出了100分以上的成绩。整个考试准备过程可以说是非常狼狈。如果身边有别人说他考托福考了5次才考到100,我口头上或许会加油打气,但内心里肯定会把他嘲笑一番。但我还是想把这段不是那么光彩的经历分享出来。我认为这段经历可以帮助到很多人。

经验

虽然我在高考之后只经历过托福这一种大考,但我现在也很能理解参加其他考试(比如考研)的考生。没有老师天天监督指导,完全凭自己去找学习方向,这确实不是一件简单的事。很多情况下,有效的学习时间不多,调整学习状态(娱乐)的时间反而会更多。这些情况都是正常的,人人都是这样过来的。有害的不是不学习本身,而是自我指责的态度。负罪感会带来压力,不当的压力才会给正常学习带来阻碍。因此,准备考试时要克服的不是想要休息的心态,而是自己给自己过度施加的压力。不管怎么样,最后考完了,考好了,没人会管你过去做了什么。我复习考试的状态也非常糟糕,经常拖拖拉拉。最后考五次也勉强熬过来了。我觉得很多人的准备情况是比我好的,调整完了心态,肯定能够比我更顺利地考完考试。

自己准备考试时,最重要的是要能够自测,随时得知自己当前的水平。只有这样,你才能知道自己的复习有没有用,成绩有没有提升。另外,最好是能够把要学习的内容用进度表示,比如看多少门的内容,背多少单词。进度的完成能够鼓舞自己。最后,如果有条件,可以和准备同一门考试的同学一起学习。大家不要攀比学习进度,而只是一起分享学习的过程,聊聊天,分担一下压力。如果心态正常,能够随时检测能力是否提升,自学考试不是什么难事。

当然,托福考试有一点特殊。托福本身的学习内容不多,不像其他考试要花很多时间去学你以前从没有学过的知识。从这点来看,托福似乎可以快速地准备好。但是,托福的水平难以测试。托福有一半的内容是主观题,作为一个考生,你是不知道这些题目是怎么批改的。因此,你无法知道自己当前的水平,也难以知道自己的复习是否能提升成绩。没有可以用进度表示的内容,没有自测方式,很多时候你根本不知道怎么复习是有效的。准备托福就像做优化函数求不了导数的优化问题,正常算法根本行不通。

对于自学托福的考生,我还有更多的经验想要分享。不谈准备考试时的认真程度,如果我提前知道了这些经验,我的准备肯定会更加顺利。我非常想把这些经验分享出来。

自学托福的注意事项

  1. 不要在第一次考试前估计自己的托福水平。在我的备考经历中,我曾经用两年前的托福成绩以及这两年英语水平的提升来预测我现在的托福水平。事实证明,我高估了自己的提升,低估了考试的要求。如果没有天天讲英语的环境,口语和听力是不会大幅提升的。而如果又没有经常写英文论文的话,写作水平也不会有太大提升。因此,对于国内大多数理工科学生,只有阅读水平会在大学期间得到提升。另外,哪怕你英语各方面都学得很好,你也要花一些时间去学习应试技巧,把英语水平转换成托福成绩。总之,在实际考出一次成绩之前,你是难以估计自己的托福水平的。最好多留一点时间,在准备得差不多的时候就可以不带压力地先考一次试,看看自己的水平。
  2. 自学口语和写作是很难的。阅读和听力都是客观题,你可以靠对答案来不断反思提升。但是,如果完全不借助外界的帮助,你是难以评估自己的口语和写作水平的。无法评估自己的水平,自然也就难以找到提升的方向。因此,如果想完全自学上100分,一定要把主要时间花在阅读和听力上,这两门要尽可能做到满分。还有时间可以去练写作。至于口语,我看了很多托福考试经验,多数自学托福的人都说口语不好提升。我也建议口语练到22分左右就够了。大多数学校不会要求更高的口语分数。相对地,如果非得要报班学习,最好是请别人帮你提升口语水平,再是帮你批改作文。阅读和听力则完全没有报班学习的必要,全部可以自学。
  3. 注意考试状态。托福考试不仅是在考英语,还在考短期记忆力。在做阅读时,你要大致记住每段的主要内容,这样才能快速地答完最后一道六选三;做听力时记忆就更重要了,答题时八成靠记忆,两成考笔记。口语和写作倒还好,只要笔记记得好,不需要靠脑子来记忆。因此,考试时,一定要保持清醒,保持注意力集中,这样才能记得住东西。托福和大学里的考试不太一样,不是你会就是会,不会就是不会,你还得在有限的时间里把题目答完(我觉得托福甚至比GRE还吃状态)。这一点和高考很像。考前一定要好好调整作息,拿备战高考的状态来备战托福。

总结一下自学托福的整体策略。不管你的英语水平有多高,最好是至少留三个月时间来从容地考完托福。第一个月好好准备,主要是准备能够自测的阅读和听力。写作和口语熟悉题型,保证能回答出来即可。由于不能自测,提升写作和口语的效率是极低的。哪怕不花太多时间准备,也没有关系。准备得差不多了,就可以先考一次试,测试一下水平,熟悉一下考试流程。考完了,如果成绩还差一点,第二个月再开始根据考试成绩做进一步的提升。要考100分以上,最好是阅读听力接近满分,口语22左右,写作能考多高考多高。根据考试结果,哪一部分离目标远就去着重提升哪一部分。至少预留一个月来进行考试,不要让不充足的时间给你带来压力。如果准备时间不够,想要报班,那就只报班提升口语和写作。

接下来,我再详细谈一谈托福的四个部分分别该怎么准备。

托福听说读写的提升方法

基础

网上有很多人会讲自己的备考经验。但是,如果英语基础不同,准备的策略也不同。因此,我想定义一个开始准备托福的最低英语水平要求。这个水平是大多数人开始准备托福的水平,适用性较广。如果你还没有达到这个水平,可以先打基础,再开始针对托福进行准备。

要学习托福,听、读、写能力要达到高考接近满分的水平。此外,要记住大部分四六级单词。对于口语,要能够在提前准备的情况下对着PPT做英语课堂展示。

听力和写作的要求应该没什么争议。毕竟大多数大学生的听力和写作都是高考水平,再之前也记不起中学老师是怎么教听力和写作的了。

对于阅读水平,我认为除了有高考水平外,还需要多记一些单词。不必去背托福单词,背完大部分四六级单词就行了。四六级单词其实是英语里的常用词。如果你熟悉了四六级单词,那么你阅读日常的英语文章是不会有单词上的问题的。剩下的一两个不会的单词,你查一查,也就记住了。我不建议去死记硬背单词,尽可能通过阅读的方式自然记忆单词。

由于各地教育水平的不同,口语很难找到一个统一的基础。在我看来,开始练托福口语之前,只要会用英语表达意思就够了。哪怕用的单词很简单,或者某些单词不会,都没关系。这种口语的基础可以通过英语演讲水平反映出来。大学里的英语课是会要求做简单的演讲或对话的,如果你能在认真准备后完成课堂展示(不必是即兴的),英语的口语基础应该就没有问题。

阅读

要做好托福阅读,要从语法、单词、阅读方法、应试技巧这几方面学习。

英语的语法不难,高中学的语法够用了。哪怕什么定语从句、宾语从句的具体规则都不记得了,也没关系,看几篇文章就回忆起来了。问题是,托福的阅读文章中有不少长句。常常有句子刚说到一半,插入一个逗号,接个从句,又接回之前那句话。这导致我们读着读着就忘掉句子前半部分在讲什么了。刚接触托福阅读时,是会碰到这种句子语法结构分析不清楚的问题。这时,可以花点时间去分析一下这些长句。只要弄懂句子主干是什么、哪个从句是修饰哪个词即可,不用过分深究语法细节。多读几个这种句子后,语法上就不会有问题了。

如前文所述,开始做托福阅读之前,有四六级的单词基础即可。剩下的单词可以通过阅读自然地学习。除去四六级的常用词外,托福的阅读材料常常会包含一些专有名词、历史名词。碰到这种看不懂的名词是非常正常的一件事,千万不要慌张,知道它们是指代某个物体即可,不懂这些单词不影响上下文的理解。如果碰到了没背过的形容词,就去查一查,稍微记一下。另外,我发现考试里的阅读单词比TPO的更简单一点,基本上都是英语里的常用词。如果做TPO碰到了较多的生词,也不必太担心。

打好了语法和单词的基础,接下来最重要的就是阅读文字的方法了。如果不限时间的话,我们当然可以一个单词一个单词去翻译,一句话一句话去理解。可是,在考试中,时间是有限的,我们必须要学会如何快速阅读英语文章。我个人认为,想读好英语文章,要能够带着自信去读,就像读中文一样。看网上的中文资讯的时候,我们肯定不会逐字逐句去读,而是很快地扫两眼,哪怕内容没看全,也大概知道文章在讲什么。读英语文章也是一样。千万不能去先把单词翻译成中文再理解,或者一边念出来一边理解。要敢于看到几个单词,就去脑中拼凑这个句子的意思,想着怎样尽快把文章读完。保持这种习惯,我们看英语就能像看中文一样快速。我认为阅读方法是托福阅读中最重要的一环。我的很多阅读方法都是以前看讲GRE长难句的教材学到的。如果准备时间充足的话,可以去看看一些有关GRE阅读的书。学会那些方法后,看托福阅读就是降维打击。另外,如果你一两年后才需要考托福,可以平时就多看看本专业的论文。带着目的去看生活中的英文文章是提升阅读水平的最佳方法。

足够的英语阅读水平,并不能保证你在托福阅读中拿到高分。应试技巧,是把能力转换成分数的关键。托福的阅读题是有一些规律的。比如问「下面哪个句子能够解释文章中的句子」,那些错误的句子要么是漏了信息,要么是逻辑关系搞错了。再比如,六选三中,一些错误的选项会有「一定」、「所有」这种过于武断的描述。这些应试技巧培训班的老师应该会讲。但我觉得,这些技巧最好是自己通过大量的练习领悟,这样就不需要花精力去记忆了。另外,托福阅读题还有一个很重要的技巧:多数题都会问哪个选项描述正确。对于这些题目,你是能够逐个判断出每个选项是正确还是错误的。这样,一道题其实可以从两个角度来做:三个选项的内容文章都没提,所以它们错了;一个选项的内容文章提到了,所以它对了。如果从两个角度你都得到了同一个答案,那这题你肯定做对了。我考试中做阅读题卡住的几次,都是排除了两个选项,觉得一个选项很对,但最后那个选项没找出它错在哪里。最后一次考试出分前我之所以觉得阅读做得很好,也是因为每道题我都确认了两遍。

总结一下,托福阅读只要无脑做TPO即可。在这个过程中,你会自然地学会分析长句,学会如何跳过不会的名词。为了加快阅读速度,你可以有意识地练习一下英语速读。多做一些题目,多找一下题目中的规律,你差不多能够凭借自己学会应试技巧。练习时,别忘了练时间分配。一开始练习时可以先在完全没有时间限制的情况下尽可能提高正确率,之后再在限定的时间内尽可能完成答题。完成所有这些的练习,就足以在托福阅读中拿到高分了。对了,考试前一定要睡醒。六选三其实就是考察你记没记住文章里的内容,要去文章里确认每个描述是否被提及。如果状态不好的话,你可能就记不住文章的内容,做六选三会非常吃力。

最后再分享一下我阅读文章的节奏。托福阅读题是顺着文章内容出的。比如前两题是和第一段相关,之后的两题是和第二段相关,只有最后一道六选三是和全文相关。因此,我阅读文章时,会先阅读一整段,再看这一段的题,题目问到哪我就再把哪看一遍。做完这一段相关的题,再去读下一段。我以前也试过先读全文再去做题,但我发现阅读全文时我对文章的理解很不到位。而一边做题一边读的话,能深刻理解每一段的意思。做完除最后一题外的所有题,也差不多懂了全文的意思。

听力

我之前已经写了一篇托福听力的准备方法了。我觉得这篇文章写得很好,方法非常系统,总结非常到位。可惜,可能文章写得太花里胡哨了,似乎看的人不多。我这里再稍微总结一下那篇文章的意思。

托福听力分三部分去准备:听力能力、记忆能力、理解能力。如果听不懂文章的内容,就去用精听等方法提升翻译英语语音的能力;如果听完了总记不住,就去练习应试技巧,尝试提前预测出题点;明明听懂了题目却做不对的情况较少,如果有,多反思一下做题时的想法,看看是不是做题做得太快了,选项没读懂。练习英语听力时,先想办法把错因归为上述三类中的某一类,再集中练习。

口语

我完全理解不了口语是怎么批改的,也没资格介绍口语经验。非得说我从五次口语成绩中分析出了什么,我只能说,越紧张,口语分越低。我最自信的第一次考试口语的分数反而最高。想口语拿高分,就不要自学,寻找外界的帮助吧。

如果给我充足的时间,我会按这个顺序准备口语:一开始练独立口语,学习一下独立口语怎么想内容(比如一个话题展开两到三个角度,或者观点正着说一遍反着说一遍),保证这题有话可说。之后练后面三道口语题,熟悉题型,知道大概的套路。再之后提升后面三道题的流利度,学习高分回答的答题方式,如一般要说几句话、材料的内容要复述得多准确、如何用连接词过度。最后还有时间,再去练习独立口语,准备一些常用的段子,尽可能减少说不出话的情况。离考试越近,越要多练口语,让自己在考试的时候充满信心。

写作

综合写作比较简单。综合写作会先给一段材料,材料中有一个观点的三个分论点。用三分钟读完材料后,会有一段听力,逐个反驳材料中的三个分论点。听完听力,在可以看到材料的情况下复述听力内容。相比听力测试,这段语音的理解难度很低。演讲人语速很慢,且关键处会反复重复。因此,在做听力笔记时,完全可以把演讲者观点中的关键词全部记下。写作时,只要把关键词串成句子即可。平时练习个一两次,学一学套路,掌握个简单的模板,知道怎么用连接词把句子串起来,考试的时候综合写作做起来就会非常轻松。

多数人讨论托福阅读都是在讨论独立写作。独立写作就完全是根据一个话题自主写作了。独立写作有两大要点:文章内容如何组织、英语用法是否地道。

文章的组织方式可以在短期内学会。准备写作时,尤其是在备考时间不够时,要把重点放在这上面。托福独立写作常常会给一个大家都听过,但不是那么好写的话题。想要漂漂亮亮地在半小时内论证这种话题,哪怕是拿中文,都是不太可能的。不过还好,托福写作考察的重点不是内容是否合理,而是英语表达能力是否过关。因此,考前要多花时间学习展开话题的技巧,知道怎么用恰到好处的废话来完成一篇英语写作。

组织文章内容,又需要考虑三个方面:文章整体结构、每段的分论点、段内句子展开方式。

文章整体结构,指的是你是无条件支持或反对题目的观点,还是一半肯定还是一半否定,还是其他什么形式。有非常多的托福写作教程会讲文章整体结构的分类。我反正只会一种叫做「一边倒」的方法,无条件支持或反对题目的观点,然后每一段都强烈地支持我的观点。这种方式最简单粗暴。其他常用的还有两段肯定,最后一个主体段委婉地让步一下。建议大家选择某一种组织方式,平时练习和考试都只用这种方式。

想每段的分论点,即怎么把话题展开,分成多个方面讨论,让自己有话可说。就比如我正在写的这篇托福经验分享,在介绍托福各个部分的准备方法时,我会讲这个部分要从哪几个方面学习,再详细谈每个方面具体怎么做。问题是,平时写文章,我都是有了要表达的东西,再用清楚的逻辑去写作。而托福写作是逼着你去对着一个话题,在三十分钟内组织一篇空洞无比的文章。我都没有想写的内容,只能生搬硬凑了。所以在我看来,想托福写作的分论点,就是要大开脑洞,能怎么扯怎么扯,只要是和主题相关的分论点即可。各个分论点之间是否有着密切的逻辑关系?没人会在意这个。不同题材,有不同组织分论点的方式。我多数情况是对着题目临场想分论点。我唯一掌握的一种模板,就是分个人、组织(公司或大学)、政府分别讨论。建议平时练习独立写作时,可以多看几个题目,不用写作,就想一想分论点。之后再看看范文的分论点,学一学别人是怎么组织的。

想好了文章结构和分论点,还要知道每一段的内容具体该怎么写。其中,开头结尾的组织方式和主体段的组织方式不同。开头结尾是有套路的,可以提前准备。比如开头先复述一下材料,讲一讲多数人的观点,最后坚决地提出自己的观点;结尾的时候总结一下每个主体段,再重申自己的观点。这两段不用写得很花哨,不是很难。最难的是各个主体段的写法。主体段一般是三个,每段你要用将近100词来详细阐述分论点。没有足够的备考经验的话,很容易碰到主体段没东西可写的情况。我认为,靠自己反复写作是提升不了写主体段的能力的。最好多参考别人的范文和一些教辅书,学习一下段内的叙述方式。一般来说,一个主体段,可以先提出观点,再用一两句话去详细解释,最后用示例丰富观点。句子之间多用连接词,多构建因果、递进、比较等关系。示例不用像高考语文一样非得用名人示例,也可以用生活中的普遍情况或者自己的例子。介绍这种写作技巧的书和资料非常多,建议大家去找一下,一定要学习别人的经验,不要闷头反复写文章。

组织段内内容时,有一些赖皮的方法。比如,可以一句话,同样的意思,颠来倒去地说。再比如,可以准备一些万能的句子,放到文章的开头结尾,或者某一主体段中。甚至文章话题不难的话,你可以把提前背好的一整段都搬过来。

举个例子。 “xxx plays a crucial role in our daily life. That is to say, xxx is important to us”。我见过诸如此类看上去华丽,但全是废话的句子。没办法,这种优美的句子确实好使,只要你不是全文都用这种废话就没事。

我对这些方法不做评价。反正托福是一个功利性很强的考试,能考高分怎么做都是对的。如果你时间有多,可以去背一背做一点准备。

除了文章的内容,剩下要考虑的就是如何让表达更加地道。比如,“孩子是祖国的花朵”显然就不是英语里会出现的说法。妥当的语言表达是一件很难在短期内学会的事。哪怕是中文,由于互联网流行词几个月一换,几年前的文章现在看来也不是那么时尚。没有语言环境是难以真正地提高语言表达的。因此,我推荐自学的考生完全放弃提升语言表达。

当然,我最近才发现一种效率较高的自学语言表达的方式:利用ChatGPT。你把自己写的文章输入进ChatGPT,让它帮你润色,它能够在完全不改原意的情况下让文字观感大幅提升。这等于是有一个教师手把手带着你练写作,告诉你哪里怎么改,怎么提升。ChatGPT的出现完全解决了英语写作无法自测的问题。我认为,如果使用得当的话,ChatGPT完全可以代替写作老师,让学生自己就找到提升写作的方向。很可惜,ChatGPT出现时我还在闷头准备托福,没来得及深入探究ChatGPT的用法,也没有使用ChatGPT进行写作练习的测试结果。我大胆预测,多数人用ChatGPT提升自己的英语表达后,托福写作至少能高2分。

以上内容主要是我从别处学到的托福写作准备方法,可以保证这些准备的方向是有效的。接下来我再稍微谈一下我自己的托福写作考试经验,并对托福写作的准备方法做一个总结。我也很纳闷,不知道为什么第一次考试写作拿了26分,后面就再也拿不到了。现在想来,估计是因为那次的题目比较简单,我的段内表达都比较流畅。后面每次独立写作时我都感觉很不顺利。在这几次考试中,我都举了和自己相关的例子。这些例子太假了,完全没有说服力。因此,我建议按以下顺序准备托福写作:尽快选择一种文章组织方式,比如最简单的「一边倒」式,并稍微学一下开头结尾的套路。之后花主要时间学习如何想分论点和组织段内句子,要做到有话可说,能够流畅自然地写完三个主体段。最后时间有多的话,可以借助ChatGPT提高语言表达能力。顺带一提,我每次独立写作就是300多字。只要够300字,字数多少并不影响成绩。当然,你有话可说,可以写出很多内容,那再好不过了。

总结

想自学托福上100分,建议写作和听力拉满,尽力学写作,口语差不多即可。

写作对着TPO无脑练就行。除了对答案提升解题能力外,着重提升快速阅读英语文本的能力。

听力也是对着TPO练。不过,要先分析自己的弱项,不要上来就去精听。在反复的自测中,逐个提升听懂、记住、写对的能力。

综合写作练一两次熟悉题目即可,稍微准备一个简单的全文结构模板。独立写作先学文章整体结构和开头结尾的技巧,再着重学习提论点和填充段内内容的方法。可以借助ChatGPT批改作文。

听力写把独立口语练到能说的程度,再熟悉三道综合口语的题型。之后主要提升综合口语,有时间多再认真准备独立口语。说口语时一定不能紧张。

搜索引擎根本搜不出有用的托福攻略。建议上留学论坛找前人的经验分享,这些帖子里经常能够找到许多学习资源(教材、作文范文)。四个部分中,写作的技巧是最需要向别人学习的。

在这篇文章中,我动之以情,晓之以理,把我在五次考托福中学到的东西全讲出来了。从今天起,我的脑中再也不想出现「托福」这两个字了。希望大家阅读后有所收获,早日战胜托福考试。

最近,我在考托福。第一次考完,我惊讶地发现我的听力只有19分(满分30)。我两年前考的成绩都不止这么一点。这两年我也一直在接触英语,英语水平不可能退步。如果托福真的是一个合理的,能够反映考生真实水平的考试,那我不可能考出差别这么大的分数。因此,我认为,考出这么差的分数,不是我的问题,是托福考试的问题。托福考试过分强调应试技巧,而无法稳定地反映我的水平。为了证明我的观点,我不服气地宣誓道:「我要用一周时间,学习听力应试技巧,考出一个较高的听力分数。」一周后,我又考了一次托福。结果很喜人,我的听力考了27分。

在这几天里,我总结了网上的托福听力准备方法,用一套规范的算法流程把方法表达了出来。我将分享一下这一套和深度学习算法形式相似的托福听力准备方法。我会从头把方法的背景、解决方法讲清楚。哪怕你不需要准备托福考试,或者说对深度学习没那么了解,也可以把这篇文章当故事读一遍。

托福听力规则

托福听力主要涉及两类材料:对话与讲座。对话通常发生在学生与教授或学校工作人员之间,描述了校园中常见的一些讨论、询问。讲座则模拟了真实的课堂教学,讲师会对某一专业话题做简要的描述,偶尔会穿插几句学生的提问。讲座涉及的知识面很广,常常会谈及艺术、历史、生物学、地理学、心理学等领域。不过,这些讲座不会讲特别深入的内容,也不会讲过于偏离生活的概念,保证多数人都能听懂讲座的内容。

一场无加试的托福听力由两轮组成。每轮会听1段对话和1~2段讲座。对话有5题,讲座有6题。

托福听力的答题形式和国内多数英语考试不同。每段听力中,考生只有听完了听力材料后,才能看到题目。并且,只有确认提交了前一题才能答下一题。当然,可以在听的同时记笔记。

题目全是选择题。每段听力材料给3~4分钟,基本不会有时间不够的情况出现。

托福这种不允许提前看题的考试模式把考生的记忆力也变成了考察目标之一,为考试增添了不合理的难度。后续算法的诸多改进都是为了解决「记不住」这一问题。

托福应试流程

大多数人在初次接触托福听力时都会采用这一套非常直观的算法:

但是,这套算法有一个问题:无论我们的听力水平多么优秀,都不可能把材料原原本本地记忆下来。一旦有题目考察了一个我们没记住的地方,这道题就答不上来了。

因此,多数托福教程会给出一套改进的算法。这套算法把「听材料」细分成了两部分。第一部分是「语音转文字」,这是我们大脑自己训练出来的功能。理解了听力材料的内容后,我们根据一些先验知识(知道听力材料的哪些部分容易出考题),对部分内容做重点记忆。重点记忆的方法可以是竖起耳朵集中注意力,也可以是用笔记记录关键词。最后,根据重点记忆的内容和大脑中残留的其他记忆,回答问题。

这套算法出色地把托福听力分成了三个独立的子任务:语音识别、记忆、阅读理解。这三个子任务恰好对应了三种要考察的能力:听力能力、记忆能力、理解能力。把这三种能力拆开来讨论是很有必要的。如果你没有意识到自己哪种能力相对相差,盲目地做题,尝试同时提升这三项能力,那你的学习是十分低效的。后文我们将基于这一套算法讨论如何分别提升这三种能力。

错因分析

在正式开始学习之前,我们一定要诊断出自己哪方面的能力有所缺失,进而对症下药,从最差的一项开始练习。为了找出做听力的问题,我们要进行错因分析。

为了考察听力、记忆、理解这三种能力,我们固然可以用做托福听力题以外的方式去分别测试这三种能力。但是,使用其他测试方式的话,我们不能保证我们测试出来的能力恰好是托福考试要求的能力。比如,你可以拿托福阅读来测试自己的理解能力。但是,由于托福阅读的理解难度比听力的理解难度要大,哪怕你阅读做得不好,也不能说明你听力就理解不好。因此,做托福听力题这件事本身就是最好的测试方式。

可是,正如前文所述,做托福听力会同时考察三种能力。该怎么分别考察这三种能力呢?其实,使用一些巧妙的控制变量法,就可以把这些能力区分出来了。我把网上提出的各种错因分析方法加以总结融合,提出了一种「反掩码错因分析」法。

什么是「反掩码」呢?众所周知,掩码 (mask),指的是计算机中用于屏蔽其他数据的一种数据,有时也泛指屏蔽数据这一过程。那么,我提出的「反掩码」,指的就是把原本不透明的数据变得透明。

在托福听力中,听力材料是不透明的。你需要通过自己的听力能力把听力材料变成可理解的文字。如果直接把听力材料变成了阅读材料,你就可以直接根据听力原文答题。这样,做题考察的就只有理解能力,而不再考察听力和记忆能力了。

因此,为了考察自己的理解能力。可以先听一遍材料,做题,不看答案,读一遍原文,再做题。第一遍做题是正常的练习,第二次做题是控制变量。如果第二次做题还是错了很多,就说明理解能力不行;如果第一遍做题相较第二遍错了很多,就说明听力和记忆存在问题。

同理,为了进一步区分听力和记忆的问题,我们可以调整反掩码,使用更精妙的控制变量方式:先听一遍材料,做题,不看答案,再听一遍材料,再做题。在第二次听力的时候,你已经知道题目了,不存在记不住的问题。如果第二遍答题还是错了很多,就说明听力和理解有问题;如果第一遍相较第二遍错了很多,就说明记忆有问题。当然,如果你的理解已经基本没问题了,在这一步就可以排除掉理解能力的影响,直接区分听力问题和记忆问题。

综上,「反掩码错因分析」的流程如下:

  1. 找出几套听力题。先听一遍材料,做题,不看答案,读一遍原文,再做题。主要分析自己的理解是否存在问题。
  2. 理解能力最容易提升。先想办法提升理解能力。
  3. 再找出几套听力题。先听一遍材料,做题,不看答案,再听一遍材料,再做题。区分听力问题和记忆问题。
  4. 根据诊断结果,先后提升听力问题和记忆问题。

能力提升

诊断出了问题后。就应该考虑如何设计子任务分别训练这三种能力了。

听力能力

使用听力能力时,输入是英语音频,输出是大脑中可理解的英文文字。这一过程完全由大脑的本能决定,几乎不需要主观思考。因此,我们可以设计一个非常直接的子任务:使用任意一种英语音频,一边听,一边「输出」自己的听到的文字。听完音频后,比较自己的输出和原文,看看自己哪些地方没听懂。不断反思,让大脑自己提升。

网上的很多托福听力攻略会叫你去「精听」,「听写」,或者使用什么「影子跟读法」。其实,所有这些方法都只是为了提升听力水平。他们的目的都是构造一个输入音频输出文字的训练任务。只是「输出」的方式不同而已。在我看来,输出文字的方式,可以是复述,可以是听写,甚至不用讲出来,在大脑里有个印象就行。所有这些具体的方法里,我不推荐听写法,因为你大量的学习时间都会浪费在写单词上。听完一句话后,复述这句话即可。至于是一句话反复听,还是听完一整段材料,这些形式都不重要。保证你在强迫自己不断输出听到的内容即可。

另外,与其纠结听力练习的具体方法,不如去花一点时间准备恰当的听力材料。对于有高考英语听力水平的人,直接拿托福官方题目(TPO)的听力材料练就可以,类似于TED的知识分享、新闻播报也可以。看电视剧的提升可能没那么快,因为电视剧的语速较快,且通常只有日常用语,与托福听力材料的内容不符。我个人推荐去上自己专业的英文公开课,比如大名鼎鼎的「MIT线性代数」。这些公开课语速适中,用语朴素,比较容易听懂,形式与托福听力类似,还可以顺便学一下专业知识。

记忆能力

托福听力考试的记忆分两类:第一类是被动记忆,也就是你听完整段材料后自然残留在大脑里的记忆;另一类是主动记忆,是你根据以往的答题经验,对材料中你认为的重点段落的记忆(或者笔记)。被动记忆我们没法操控,只能祈求考试当天头脑清醒一点。因此,练习记忆时,主要是练习托福听力应试技巧,提高对出题点的灵敏度,并且做到在不影响听力的情况下记下笔记。

在培训班或者网上,都能找到大量的托福听力技巧。这些技巧告诉你什么地方容易出题,听到哪些词的时候做笔记之类的。但是,我觉得背技巧的效率是很低的。最好的方式是,自己去尝试发现技巧,有了一定的经验后再去和别人的技巧对比。

那么,怎么去学会洞察听力材料的出题点呢?有两个方面的事要做。一方面,做完了题目后,总结归纳常见题型,并且找题目对应的原文,总结规律;另一方面,在听材料时,多做一点笔记,做题时看看哪些笔记用到了,哪些没有。甚至,对于以前练过的材料,可以直接跳过听力,读原文,猜出题点。通过正向和反向的练习,很快就能掌握技巧。

作为示例,我来分享几个我发现的技巧。

  • 对话材料90%会问对话发起的原因是什么。注意,这道题不是问整段对话的主题是什么,而是问学生为什么找老师或老师为什么找学生。很可能两个人寒暄了老半天,学生突然支支吾吾地说出自己过来的原因,然后话题一下就飞走了。如果你没有预判出题点,很可能这一句话就被你忽略了。
  • 讲座有时会先讲理论,再讲示例。题目会问这个示例是揭示哪个理论。如果示例里恰好冒出两个专有名词,你又不知道这里会出题,这个例子就会被你忘掉了。因此,听到示例,赶快把示例里涉及的几个名词记下来。
  • 讲座中,老师和学生都可能会表达对于某一理论的态度。尤其是结尾,老师很可能冷不丁地说「虽然这个理论很有名,但我并不是很认可」。题目很喜欢考这些态度。你可能辛辛苦苦听了4分钟的材料,想着听力要结束了,可以放松一下了。结果最后这两句很重要的态度就被你忽略了。

找到材料中的出题点,其实只是比较初级的技巧。如果你想成为高手,还有一种高阶的,更一般的方法。我把它称之为「基于注意力的记忆法」。听力材料,甚至是阅读材料,以及我们生活中各种各样的文字讯息,都是有很多废话的。如果你知道哪些话言之无物,你就可以过滤掉这些话,而集中记忆那些有意义的话。更进一步,如果你知道听力中哪些话一定不会出题,那么这些话也可以过滤掉。你可以只把注意力集中在哪些容易出题的有意义的话上。

比如说,老师讲「好久不见」、「我们上节课讲了什么什么,但是这节课……」。这些话都可以直接当成废话忽略掉。讲座时提到某个名人出生于何处,几几年在哪上学,这种过于细节的地方也不可能出题,可以忽略掉。

这种基于注意力的记忆法是很有用的。一来,它可以帮你找到重点和非重点的话,有助于合理安排精力;二来,在知道这句话不重要时,可以放心地去记上一句话的笔记。

这种记忆法不用刻意训练,多听了几段材料后就能自然地知道听力材料中哪些地方是可以忽略的。

理解能力

托福听力中的理解,既包括对材料的理解,也包含对题目和选项的理解。整体来说,托福听力对理解的要求不高。只要清楚地听懂了材料,做题时一般不会因为理解而扣分。

材料理解比较简单。只有介绍一个复杂的理科概念,或者人物表明态度时过于委婉,才可能导致材料理解错误。只要把托福阅读题练好,听力材料不可能有理解问题。

理解题目、选项时则可能会碰到一些问题。比如在主旨题中,每一个选项概括主旨都概括得不是很确切,但是其中三个选项有明显错误,一个选项概括得不全。这个时候,得选择那个概括得不全的选项。为了解决这个问题,只需要多做点题,总结记录选项理解错误导致的错题,很快避开常见的一些坑。另外,听力的时间非常充足,碰到模棱两可的选项时可以多读两遍题,千万不要把题目读错。

综合练习

TPO 30之前的题都比较简单,可以拿这些题目来反复训练各个子能力。预训练好了大脑中的各个子模块后,就可以直接开始做TPO 30之后的题,进一步综合训练所有能力。

综合练习时,应按照正常的应试流程,一边听材料,一边做笔记。听完了材料就做题,对答案。有题目做错了,大致把错因归个类,看看是三步中哪一步出了问题。根据错题的情况做进一步的查缺补漏。

在评估自己的综合练习水平时,除了对答案,还需要反思两件事:材料中的哪几句话没有听清;笔记是否记下了重要信息。

很多时候,你可能听漏了某几句话,但依然把题目都做对了。但这样并不保险,最好是把每一句话都理解透来,保证听力能力没有问题。

哪怕你已经熟悉了材料中的出题点,依然需要练习一下记笔记的方法。首先,你要保证记笔记的时机合理,不能影响听力。其次,你要保证重要的信息没有记漏。在不影响听力的前提下,多记笔记是没有坏处的。因此,反思时,只需要评估哪些题目中考察的信息是漏记的,尤其是哪些导致你做题出错的漏记的信息。

最后,再总结一下各种能力在托福听力中的重要性。理解能力是必须的,也是难度最低的。一定要把理解能力练到满分。听力能力是答题的基础,是托福听力题的主要考察对象。听力能力是最考验基本功,最难在短期提升的,一定要早做准备。记忆能力是把听力能力转换成分数的必备途径,你可能文章听得津津有味,答题时却头脑发昏。记忆能力最好提升,几天内就可以总结套路,掌握记忆关键信息的方法。

方法总结

我们来把这篇文章讲过的托福听力准备方法和应试方法整理一下。

准备考试:

  1. 做几套TPO听力,使用分析错因,找出自己较弱的能力。分析错因时,先听一遍,做题,然后在不看答案的前提下再听一遍或看一遍原文,排除出做题出错的原因。
  2. 从较弱的能力开始,逐个提升提升能力。
  3. 练得差不多了以后,综合练习,查缺补漏,同时练习应试的状态与技巧。

考试时:

  1. 听语音,在脑中把语音变成文字,判断这句话有没有重要信息。
  2. 如果觉得这句话很重要,就着重记在脑子里,或者用笔记记下。对话可以少记或不记笔记。讲座必须记下关键信息,比如举例、分类讨论、态度之类的。
  3. 根据笔记或印象答题。

托福听力题本质上还是在考察听力能力。如果你听力的基础不好,还是要把主要的时间花在打基础上。等听力水平差不多了,可以按照这篇文章介绍的内容,学习一些技巧,把听力能力转换成分数。希望大家能够考出一个不错的托福分数。

附录:托福听力框架图