在之前的文章中,我介绍了如何用NumPy实现卷积正向传播。
在这篇文章里,我会继续介绍如何用NumPy复现二维卷积的反向传播,并用PyTorch来验证结果的正确性。通过阅读这篇文章,大家不仅能进一步理解卷积的实现原理,更能领悟到一般算子的反向传播实现是怎么推导、编写出来的。
项目网址:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/BasicCNN
本文代码在dldemos/BasicCNN/np_conv_backward.py
这个文件里。
实现思路
回忆一下,在正向传播中,我们是这样做卷积运算的:
1 | for i_h in range(h_o): |
我们遍历输出图像的每一个位置,选择该位置对应的输入图像切片和卷积核,做一遍乘法,再加上bias。
其实,一轮运算写成数学公式的话,就是一个线性函数y=wx+b
。对w, x, b
求导非常简单:
1 | dw_i = x * dy |
在反向传播中,我们只需要遍历所有这样的线性运算,计算这轮运算对各参数的导数的贡献即可。最后,累加所有的贡献,就能得到各参数的导数。当然,在用代码实现这段逻辑时,可以不用最后再把所有贡献加起来,而是一算出来就加上。
1 | dw += x * dy |
这里要稍微补充一点。在前向传播的实现中,我加入了dilation, groups
这两个参数。为了简化反向传播的实现代码,只展示反向传播中最精华的部分,我在这份卷积实现中没有使用这两个参数。
代码实现
在开始实现反向传播之前,我们先思考一个问题:反向传播的函数应该有哪些参数?从数学上来讲,反向传播和正向传播的参数是相反的。设正向传播的输入是A_prev, W, b
(输入图像、卷积核组、偏差),则应该输出Z
(输出图像)。那么,在反向传播中,应该输入dZ
,输出dA_prev, dW, db
。可是,在写代码时,我们还需要一些其他的输入参数。
我的反向传播函数的函数定义如下:
1 | def conv2d_backward(dZ: np.ndarray, cache: Dict[str, np.ndarray], stride: int, |
虽然我这里把所有参数都写在了一起,但从逻辑上来看,这些参数应该分成三个类别。在编程框架中,这三类参数会储存在不同的地方。
dZ
: 反向传播函数真正的输入。cache
: 正向传播中的一些中间变量Z, W, b
。由于我们必须在一个独立的函数里完成反向传播,这些中间变量得以输入参数的形式供函数访问。stride, padding
: 这两个参数是卷积的属性。如果卷积层是用一个类表示的话,这些参数应该放在类属性里,而不应该放在反向传播的输入里。
给定这三类参数,就足以完成反向传播计算了。下面我来介绍conv2d_backward
的具体实现。
首先,获取cache
中的参数,并且新建储存梯度的张量。
1 | W = cache['W'] |
之后,为了实现填充操作,我们要把A_prev
和dA_prev
都填充一下。注意,算完了所有梯度后,别忘了要重新把dA_prev
从dA_prev_pad
里抠出来。
1 | A_prev_pad = np.pad(A_prev, [(padding, padding), (padding, padding), |
接下来,就是梯度的计算了。
1 | for i_h in range(h_o): |
在算导数时,我们应该对照着正向传播的计算,算出每一条计算对导数的贡献。如前文所述,卷积操作只是一个简单的y=wx+b
,把对应的w, x, b
从变量里正确地取出来并做运算即可。
最后,要把这些导数返回。别忘了把填充后的dA_prev
恢复一下。
1 | if padding > 0: |
这里有一个细节:如果padding==0
,则在取切片时范围会变成[0:-0]
,这样会取出一个长度为0
的切片,而不是我们期望的原长度的切片。因此,要特判一下padding<=0
的情况。
单元测试
为了方便地进行单元测试,我使用了pytest这个单元测试库。可以直接pip一键安装:
1 | pip install pytest |
之后就可以用pytest执行我的这份代码,代码里所有以test_
开头的函数会被认为是单元测试的主函数。
1 | pytest dldemos/BasicCNN/np_conv_backward.py |
单元测试函数的定义如下:
1 |
|
@pytest.mark.parametrize
用于设置单元测试参数的可选值。我设置了4组参数,每组参数有2个可选值,经过排列组合后可以生成2^4=16
个单元测试,pytest会自动帮我们执行不同的测试。
在单元测试中,我打算测试conv2d
在各种输入通道数、输出通道数、卷积核大小、步幅、填充数的情况。
测试函数是这样写的:
1 | def test_conv(c_i: int, c_o: int, kernel_size: int, stride: int, padding: str): |
整个测试函数可以分成三部分:变量预处理、前向传播、反向传播。在前向传播和反向传播中,我们要分别用刚编写的卷积核PyTorch中的卷积进行计算,并比较两个运算结果是否相同。
预处理时,我们要创建NumPy和PyTorch的输入。
1 | # Preprocess |
之后是正向传播。计算结果和中间变量会被存入cache
中。
1 | # forward |
最后是反向传播。在那之前,要补充说明一下如何在PyTorch里手动求一些数据的导数。在PyTorch中,各个张量默认是不可训练的。为了让框架知道我们想求哪几个参数的导数,我们要执行张量的required_grad_()
方法,如:
1 | torch_input = torch.from_numpy(np.transpose( |
这样,在正向传播时,PyTorch就会自动把对可训练参数的运算搭成计算图了。
正向传播后,对结果张量调用backward()
即可执行反向传播。但是,PyTorch要求调用backward()
的张量必须是一个标量,也就是它不能是矩阵,不能是任何长度大于1的数据。而这里PyTorch的卷积结果又是一个四维张量。因此,我把PyTorch卷积结果做了求和,得到了一个标量,用它来调用backward()
。
1 | torch_sum = torch.sum(torch_output_tensor) |
这样,就可以用tensor.grad
获取tensor
的导数了,如
1 | torch_weight.grad |
整个反向传播测试的代码如下。
1 | # backward |
再补充一下,在求导时,运算结果的导数是1。因此,新建dZ
时,我用的是np.ones
(全1张量)。同理,PyTorch也会默认运算结果的导数为1,即这里torch_sum.grad==1
。而执行加法运算不会改变导数,所以torch_output_tensor.grad
也是一个全是1的张量,和NumPy的dZ
的值是一模一样的。
写完单元测试函数后,运行前面提到的单元测试命令,pytest就会输出很多测试的结果。
1 | pytest dldemos/BasicCNN/np_conv_backward.py |
如果看到了类似的输出,就说明我们的代码是正确的。
1 | ==== 16 passed in 1.04s ==== |
反向传播的编写思路
通过阅读上面的实现过程,相信大家已经明白如何编写卷积的反向传播了。接下来,我将总结一下实现一般算子的正向、反向传播的思路。无论是用NumPy,还是PyTorch等编程框架,甚至是纯C++,这种思路都是适用的。
一开始,我们要明白,一个算子总共会涉及到这些参数:
- 输入与输出:算子的输入张量和输出张量。正向传播和反向传播的输入输出恰好是相反的。
- 属性:算子的超参数。比如卷积的
stride, padding
。 - 中间变量:前向传播传递给反向传播的变量。
一般情况下,我们应该编写一个算子类。在初始化算子类时,算子的属性就以类属性的形式存储下来了。
在正向传播时,我们按照算子定义直接顺着写下去就行。这个时候,可以先准备好cache
变量,但先不去管它,等写到反向传播的时候再处理。
接着,编写反向传播。由于反向传播和正向传播的运算步骤相似,我们可以直接把正向传播的代码复制一份。在这个基础上,思考每一步正向传播运算产生了哪些导数,对照着写出导数计算的代码即可。这时,我们会用到一些正向传播的中间结果,这下就可以去正向传播代码里填写cache
,在反向传播里取出来了。
最后,写完了算子,一定要做单元测试。如果该算子有现成的实现,用现成的实现来对齐运算结果是最简单的一种实现单元测试的方式。
总结
在这篇文章中,我介绍了以下内容:
- 卷积反向传播的NumPy实现
- 如何用PyTorch手动求导
- 如何编写完整的算子单元测试
- 实现算子正向传播、反向传播的思路
如果你也想把代码基础打牢,一定一定要像这样自己动手从头写一份代码。在写代码,调bug的过程中,一定会有很多收获。
由于现在的编程框架都比较成熟,搞科研时基本不会碰到自己动手写底层算子的情况。但是,如果你想出了一个特别棒的idea,想出了一个全新的神经网络模块,却在写代码时碰到了阻碍,那可就太可惜了。学一学反向传播的实现还是很有用的。
在模型部署中,反向传播可能完全派不上用场。但是,一般框架在实现算子的正向传播时,是会照顾反向传播的。也就是说,如果抛掉反向传播,正向传播的实现或许可以写得更加高效。这样看来,了解反向传播的实现也是很有帮助的。我们可以用这些知识看懂别人的正向传播、反向传播的实现,进而优化代码的性能。
附录:完整代码
1 | from typing import Dict, Tuple |