
Andrew Ng's Deep Learning Specialization in Code (11): Implementing ResNet in TensorFlow and Verifying the Effectiveness of Residual Connections

In this post, I will show how to implement the following four models in TensorFlow:

  1. ResNet-18
  2. ResNet-18 without shortcut connections
  3. ResNet-50
  4. ResNet-50 without shortcut connections

Once the implementations are done, I will train these four models on a simple dataset. From the experiment results, we can see intuitively what the residual connections in ResNet contribute.

Project link: https://github.com/SingleZombie/DL-Demos

The main code is in dldemos/ResNet/tf_main.py.

Model Implementation

Overall Structure

ResNet contains shortcut connections, which are awkward to express with a purely serial tf.keras.Sequential model. Instead, we connect the model's modules ourselves. In TensorFlow, the overall pattern looks like this:

# Initialize input
input = layers.Input(input_shape)

# Get output
output = ...

# Build model
model = models.Model(inputs=input, outputs=output)
model.summary()
return model

After layers.Input creates an input tensor, we can apply computations to that tensor and finally use tf.keras.models.Model to assemble the computation graph built around it.
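
For example, here is a minimal self-contained sketch of this pattern (the specific layers are placeholders of mine; only the wiring matters):

from tensorflow.keras import layers, models

input = layers.Input((224, 224, 3))
x = layers.Conv2D(8, 3, padding='same')(input)  # any computation on the tensor
x = layers.Flatten()(x)
output = layers.Dense(1, 'sigmoid')(x)
model = models.Model(inputs=input, outputs=output)  # assembles the graph

These two imports (layers and models) are presumably also what all the code below assumes.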

Next, let's look at how this output is actually computed.

def init_model(
        input_shape=(224, 224, 3), model_name='ResNet18', use_shortcut=True):
    # Initialize input
    input = layers.Input(input_shape)

    # Get output
    x = layers.Conv2D(64, 7, (2, 2), padding='same')(input)
    x = layers.MaxPool2D((3, 3), (2, 2))(x)

    if model_name == 'ResNet18':
        x = identity_block_2(x, 3, use_shortcut)
        x = identity_block_2(x, 3, use_shortcut)
        x = convolution_block_2(x, 3, 128, 2, use_shortcut)
        x = identity_block_2(x, 3, use_shortcut)
        x = convolution_block_2(x, 3, 256, 2, use_shortcut)
        x = identity_block_2(x, 3, use_shortcut)
        x = convolution_block_2(x, 3, 512, 2, use_shortcut)
        x = identity_block_2(x, 3, use_shortcut)
    elif model_name == 'ResNet50':

        def block_group(x, fs1, fs2, count):
            x = convolution_block_3(x, 3, fs1, fs2, 2, use_shortcut)
            for i in range(count - 1):
                x = identity_block_3(x, 3, fs1, fs2, use_shortcut)
            return x

        x = block_group(x, 64, 256, 3)
        x = block_group(x, 128, 512, 4)
        x = block_group(x, 256, 1024, 6)
        x = block_group(x, 512, 2048, 3)
    else:
        raise NotImplementedError(f'No such model {model_name}')

    x = layers.AveragePooling2D((2, 2), (2, 2))(x)
    x = layers.Flatten()(x)
    output = layers.Dense(1, 'sigmoid')(x)

    # Build model
    model = models.Model(inputs=input, outputs=output)
    model.summary()
    return model

When building the model, we need to supply the shape of the input tensor. The function selects the architecture with model_name and controls whether shortcut connections are used with use_shortcut.

def init_model(
        input_shape=(224, 224, 3), model_name='ResNet18', use_shortcut=True):
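
With this signature, the four variants compared in the experiments are constructed as follows:

# The four experiment configurations
model = init_model()                                            # ResNet-18
model = init_model(use_shortcut=False)                          # ResNet-18, no shortcuts
model = init_model(model_name='ResNet50')                       # ResNet-50
model = init_model(model_name='ResNet50', use_shortcut=False)   # ResNet-50, no shortcuts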

ResNet has two main kinds of residual blocks.

The first kind (drawn with solid shortcut lines in the paper's figures) has identical input and output shapes, so the input can be added directly to the output just before the activation function. The second kind (drawn with dashed shortcut lines) has different input and output shapes, so a 1x1 convolution is needed to adjust the spatial size and channel count.

In addition, each kind of residual block has two implementations.

The first implementation consists of two 3x3 convolutions with the same channel count, so only one channel count has to be chosen. The second uses a bottleneck structure: a 1x1 convolution first reduces the channel count, then a 3x3 convolution is applied, so two channel counts have to be chosen (one shared by the first and second convolutions, and one for the third).

In the code, identity_block_2 and identity_block_3 are the two implementations of the residual block whose input and output shapes match, and convolution_block_2 and convolution_block_3 are the two implementations of the block whose shapes differ. Their code is given in the next subsection.

Now let's see how to build ResNet-18 and ResNet-50 out of these modules. First, recall the architecture figure from the original paper.

In both architectures, the input first passes through a large convolution layer and a pooling layer, and at the end there is an average pooling followed by a fully connected layer. The difference lies in the convolution stages in between: ResNet-18 and ResNet-50 use convolution stages that differ both in block implementation and in block count.
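
Since the figure is not reproduced here, the stage configuration from the paper is, for reference:

  - ResNet-18: stages conv2_x to conv5_x each stack 2 basic blocks, with 64, 128, 256, and 512 channels respectively.
  - ResNet-50: the same four stages stack 3, 4, 6, and 3 bottleneck blocks, with output channel counts 256, 512, 1024, and 2048.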

In the code, the initial large convolution and pooling are written as:

x = layers.Conv2D(64, 7, (2, 2), padding='same')(input)
x = layers.MaxPool2D((3, 3), (2, 2))(x)

The ResNet-18 branch is:

if model_name == 'ResNet18':
    x = identity_block_2(x, 3, use_shortcut)
    x = identity_block_2(x, 3, use_shortcut)
    x = convolution_block_2(x, 3, 128, 2, use_shortcut)
    x = identity_block_2(x, 3, use_shortcut)
    x = convolution_block_2(x, 3, 256, 2, use_shortcut)
    x = identity_block_2(x, 3, use_shortcut)
    x = convolution_block_2(x, 3, 512, 2, use_shortcut)
    x = identity_block_2(x, 3, use_shortcut)

Here, the arguments of identity_block_2 are the input tensor, the kernel size, and whether to use the shortcut. The arguments of convolution_block_2 are the input tensor, the kernel size, the output channel count, the stride, and whether to use the shortcut.

The ResNet-50 branch is:

elif model_name == 'ResNet50':

    def block_group(x, fs1, fs2, count):
        x = convolution_block_3(x, 3, fs1, fs2, 2, use_shortcut)
        for i in range(count - 1):
            x = identity_block_3(x, 3, fs1, fs2, use_shortcut)
        return x

    x = block_group(x, 64, 256, 3)
    x = block_group(x, 128, 512, 4)
    x = block_group(x, 256, 1024, 6)
    x = block_group(x, 512, 2048, 3)

Here, the arguments of identity_block_3 are the input tensor, the kernel size, the middle and output channel counts, and whether to use the shortcut. The arguments of convolution_block_3 are the input tensor, the kernel size, the middle and output channel counts, the stride, and whether to use the shortcut.

Finally, the classification output is computed by:

x = layers.AveragePooling2D((2, 2), (2, 2))(x)
x = layers.Flatten()(x)
output = layers.Dense(1, 'sigmoid')(x)
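
This head performs binary classification (the dataset used later is cat vs. non-cat). For a multi-class dataset, the last layer would instead be a softmax; a hypothetical sketch (num_classes is not a variable in this project):

# Hypothetical multi-class head, not used in this post:
# output = layers.Dense(num_classes, 'softmax')(x)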

Residual Block Implementation

# Residual block with two 3x3 convolutions; input and output shapes match,
# so the input can be added to the output directly.
def identity_block_2(x, f, use_shortcut=True):
    _, _, _, C = x.shape
    x_shortcut = x
    x = layers.Conv2D(C, f, padding='same')(x)
    x = layers.BatchNormalization(axis=3)(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(C, f, padding='same')(x)
    x = layers.BatchNormalization(axis=3)(x)
    if use_shortcut:
        x = x + x_shortcut
    x = layers.ReLU()(x)
    return x

# Residual block with two 3x3 convolutions; the first convolution downsamples
# with stride s and changes the channel count, so the shortcut needs a strided
# 1x1 projection to match.
def convolution_block_2(x, f, filters, s: int, use_shortcut=True):
    x_shortcut = x
    x = layers.Conv2D(filters, f, strides=(s, s), padding='same')(x)
    x = layers.BatchNormalization(axis=3)(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters, f, padding='same')(x)
    x = layers.BatchNormalization(axis=3)(x)
    if use_shortcut:
        x_shortcut = layers.Conv2D(filters, 1, strides=(s, s),
                                   padding='valid')(x_shortcut)
        x_shortcut = layers.BatchNormalization(axis=3)(x_shortcut)
        x = x + x_shortcut
    x = layers.ReLU()(x)
    return x

# Bottleneck residual block (1x1 -> 3x3 -> 1x1); input and output shapes match.
def identity_block_3(x, f, filters1, filters2, use_shortcut=True):
    x_shortcut = x
    x = layers.Conv2D(filters1, 1, padding='valid')(x)
    x = layers.BatchNormalization(axis=3)(x)
    # NOTE: the reference ResNet applies a ReLU after this BN as well;
    # it is omitted here.
    x = layers.Conv2D(filters1, f, padding='same')(x)
    x = layers.BatchNormalization(axis=3)(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters2, 1, padding='valid')(x)
    x = layers.BatchNormalization(axis=3)(x)
    if use_shortcut:
        x = x + x_shortcut
    x = layers.ReLU()(x)
    return x

# Bottleneck residual block whose stride-s first convolution downsamples and
# whose 1x1 convolutions change the channel count; the shortcut needs a
# strided 1x1 projection to match.
def convolution_block_3(x, f, filters1, filters2, s: int, use_shortcut=True):
    x_shortcut = x
    x = layers.Conv2D(filters1, 1, strides=(s, s), padding='valid')(x)
    x = layers.BatchNormalization(axis=3)(x)
    x = layers.Conv2D(filters1, f, padding='same')(x)
    x = layers.BatchNormalization(axis=3)(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters2, 1, padding='valid')(x)
    x = layers.BatchNormalization(axis=3)(x)
    if use_shortcut:
        x_shortcut = layers.Conv2D(filters2, 1, strides=(s, s),
                                   padding='same')(x_shortcut)
        x_shortcut = layers.BatchNormalization(axis=3)(x_shortcut)
        x = x + x_shortcut
    x = layers.ReLU()(x)
    return x
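
To sanity-check these blocks, we can run a symbolic tensor through them and confirm the shapes. This is a minimal sketch assuming the four functions above are defined; the expected shapes follow from the strides and paddings used:

# Quick shape sanity check for the residual blocks
inp = layers.Input((56, 56, 64))
print(identity_block_2(inp, 3).shape)                 # (None, 56, 56, 64): shape preserved
print(convolution_block_2(inp, 3, 128, 2).shape)      # (None, 28, 28, 128): downsampled
print(convolution_block_3(inp, 3, 64, 256, 2).shape)  # (None, 28, 28, 256): bottleneck, downsampled
print(identity_block_3(layers.Input((56, 56, 256)), 3, 64, 256).shape)  # (None, 56, 56, 256)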

One detail in this code deserves attention: in convolution_block_3, there is no consensus on whether the stride of 2 should go in the first or the second convolution layer, and frameworks differ here. For instance, the original implementation puts the stride on the first 1x1 convolution, while torchvision's ResNet (sometimes called ResNet v1.5) moves it to the 3x3 convolution. This implementation puts it in the first 1x1 convolution.
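
As a sketch of the alternative, only the main path's first two convolutions would change (this is not the variant used in this post; the remaining 1x1 convolution and the shortcut stay the same):

# Alternative placement: stride on the 3x3 conv instead of the first 1x1 conv
x = layers.Conv2D(filters1, 1, padding='valid')(x)
x = layers.BatchNormalization(axis=3)(x)
x = layers.Conv2D(filters1, f, strides=(s, s), padding='same')(x)
x = layers.BatchNormalization(axis=3)(x)
x = layers.ReLU()(x)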

Experiment Results

For this project, I have already prepared the dataset preprocessing code, so the dataset can be generated and the models trained with TensorFlow without much effort.

def main():
    train_X, train_Y, test_X, test_Y = get_cat_set(
        'dldemos/LogisticRegression/data/archive/dataset',
        train_size=500,
        test_size=50)
    print(train_X.shape)  # (m, 224, 224, 3)
    print(train_Y.shape)  # (m, 1)

    # model = init_model()
    # model = init_model(use_shortcut=False)
    model = init_model(model_name='ResNet50')
    # model = init_model(model_name='ResNet50', use_shortcut=False)
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    model.fit(train_X, train_Y, epochs=20, batch_size=16)
    model.evaluate(test_X, test_Y)

To keep training time short, I trained for only 20 epochs on a fairly small dataset: 3000 training samples for ResNet-18 and 1000 for ResNet-50. The amount of data does not affect the comparison; the training error alone is enough to compare these four models.

Below are the results I got in the four experiments.

ResNet-18

Epoch 1/20
63/63 [==============================] - 75s 1s/step - loss: 1.9463 - accuracy: 0.5485
Epoch 2/20
63/63 [==============================] - 71s 1s/step - loss: 0.9758 - accuracy: 0.5423
Epoch 3/20
63/63 [==============================] - 81s 1s/step - loss: 0.8490 - accuracy: 0.5941
Epoch 4/20
63/63 [==============================] - 73s 1s/step - loss: 0.8309 - accuracy: 0.6188
Epoch 5/20
63/63 [==============================] - 72s 1s/step - loss: 0.7375 - accuracy: 0.6402
Epoch 6/20
63/63 [==============================] - 77s 1s/step - loss: 0.7932 - accuracy: 0.6769
Epoch 7/20
63/63 [==============================] - 78s 1s/step - loss: 0.7782 - accuracy: 0.6713
Epoch 8/20
63/63 [==============================] - 76s 1s/step - loss: 0.6272 - accuracy: 0.7147
Epoch 9/20
63/63 [==============================] - 77s 1s/step - loss: 0.6303 - accuracy: 0.7059
Epoch 10/20
63/63 [==============================] - 74s 1s/step - loss: 0.6250 - accuracy: 0.7108
Epoch 11/20
63/63 [==============================] - 73s 1s/step - loss: 0.6065 - accuracy: 0.7142
Epoch 12/20
63/63 [==============================] - 74s 1s/step - loss: 0.5289 - accuracy: 0.7754
Epoch 13/20
63/63 [==============================] - 73s 1s/step - loss: 0.5005 - accuracy: 0.7506
Epoch 14/20
63/63 [==============================] - 73s 1s/step - loss: 0.3961 - accuracy: 0.8141
Epoch 15/20
63/63 [==============================] - 74s 1s/step - loss: 0.4417 - accuracy: 0.8121
Epoch 16/20
63/63 [==============================] - 74s 1s/step - loss: 0.3761 - accuracy: 0.8136
Epoch 17/20
63/63 [==============================] - 73s 1s/step - loss: 0.2764 - accuracy: 0.8809
Epoch 18/20
63/63 [==============================] - 71s 1s/step - loss: 0.2698 - accuracy: 0.8878
Epoch 19/20
63/63 [==============================] - 72s 1s/step - loss: 0.1483 - accuracy: 0.9457
Epoch 20/20
63/63 [==============================] - 72s 1s/step - loss: 0.2495 - accuracy: 0.9079

ResNet-18 without shortcut connections

Epoch 1/20
63/63 [==============================] - 63s 963ms/step - loss: 1.4874 - accuracy: 0.5111
Epoch 2/20
63/63 [==============================] - 62s 990ms/step - loss: 0.7654 - accuracy: 0.5386
Epoch 3/20
63/63 [==============================] - 65s 1s/step - loss: 0.6799 - accuracy: 0.6210
Epoch 4/20
63/63 [==============================] - 62s 990ms/step - loss: 0.6891 - accuracy: 0.6086
Epoch 5/20
63/63 [==============================] - 65s 1s/step - loss: 0.7921 - accuracy: 0.5182
Epoch 6/20
63/63 [==============================] - 65s 1s/step - loss: 0.7123 - accuracy: 0.5643
Epoch 7/20
63/63 [==============================] - 64s 1s/step - loss: 0.7071 - accuracy: 0.5173
Epoch 8/20
63/63 [==============================] - 64s 1s/step - loss: 0.6653 - accuracy: 0.6227
Epoch 9/20
63/63 [==============================] - 65s 1s/step - loss: 0.6675 - accuracy: 0.6249
Epoch 10/20
63/63 [==============================] - 64s 1s/step - loss: 0.6959 - accuracy: 0.6130
Epoch 11/20
63/63 [==============================] - 66s 1s/step - loss: 0.6730 - accuracy: 0.6182
Epoch 12/20
63/63 [==============================] - 63s 1s/step - loss: 0.6321 - accuracy: 0.6491
Epoch 13/20
63/63 [==============================] - 63s 992ms/step - loss: 0.6413 - accuracy: 0.6569
Epoch 14/20
63/63 [==============================] - 63s 1s/step - loss: 0.6130 - accuracy: 0.6885
Epoch 15/20
63/63 [==============================] - 62s 988ms/step - loss: 0.6750 - accuracy: 0.6056
Epoch 16/20
63/63 [==============================] - 66s 1s/step - loss: 0.6341 - accuracy: 0.6526
Epoch 17/20
63/63 [==============================] - 68s 1s/step - loss: 0.6384 - accuracy: 0.6676
Epoch 18/20
63/63 [==============================] - 65s 1s/step - loss: 0.5750 - accuracy: 0.6997
Epoch 19/20
63/63 [==============================] - 63s 997ms/step - loss: 0.5932 - accuracy: 0.7094
Epoch 20/20
63/63 [==============================] - 62s 990ms/step - loss: 0.6133 - accuracy: 0.6420

ResNet-50

Epoch 1/20
63/63 [==============================] - 72s 1s/step - loss: 3.4660 - accuracy: 0.4970
Epoch 2/20
63/63 [==============================] - 67s 1s/step - loss: 1.3429 - accuracy: 0.5686
Epoch 3/20
63/63 [==============================] - 68s 1s/step - loss: 1.0294 - accuracy: 0.5616
Epoch 4/20
63/63 [==============================] - 68s 1s/step - loss: 0.7920 - accuracy: 0.6186
Epoch 5/20
63/63 [==============================] - 70s 1s/step - loss: 0.6698 - accuracy: 0.6773
Epoch 6/20
63/63 [==============================] - 70s 1s/step - loss: 0.6884 - accuracy: 0.7289
Epoch 7/20
63/63 [==============================] - 70s 1s/step - loss: 0.7144 - accuracy: 0.6399
Epoch 8/20
63/63 [==============================] - 69s 1s/step - loss: 0.7088 - accuracy: 0.6698
Epoch 9/20
63/63 [==============================] - 68s 1s/step - loss: 0.6385 - accuracy: 0.6446
Epoch 10/20
63/63 [==============================] - 69s 1s/step - loss: 0.5389 - accuracy: 0.7417
Epoch 11/20
63/63 [==============================] - 71s 1s/step - loss: 0.4954 - accuracy: 0.7832
Epoch 12/20
63/63 [==============================] - 73s 1s/step - loss: 0.4489 - accuracy: 0.7782
Epoch 13/20
63/63 [==============================] - 69s 1s/step - loss: 0.3987 - accuracy: 0.8257
Epoch 14/20
63/63 [==============================] - 72s 1s/step - loss: 0.3228 - accuracy: 0.8519
Epoch 15/20
63/63 [==============================] - 70s 1s/step - loss: 0.2089 - accuracy: 0.9235
Epoch 16/20
63/63 [==============================] - 69s 1s/step - loss: 0.4766 - accuracy: 0.7756
Epoch 17/20
63/63 [==============================] - 75s 1s/step - loss: 0.2148 - accuracy: 0.9181
Epoch 18/20
63/63 [==============================] - 70s 1s/step - loss: 0.3086 - accuracy: 0.8623
Epoch 19/20
63/63 [==============================] - 69s 1s/step - loss: 0.3544 - accuracy: 0.8732
Epoch 20/20
63/63 [==============================] - 70s 1s/step - loss: 0.0796 - accuracy: 0.9704

ResNet-50 without shortcut connections

Epoch 1/20
63/63 [==============================] - 60s 882ms/step - loss: 1.2093 - accuracy: 0.5034
Epoch 2/20
63/63 [==============================] - 56s 892ms/step - loss: 0.8433 - accuracy: 0.4861
Epoch 3/20
63/63 [==============================] - 59s 931ms/step - loss: 0.7512 - accuracy: 0.5235
Epoch 4/20
63/63 [==============================] - 62s 991ms/step - loss: 0.7395 - accuracy: 0.4887
Epoch 5/20
63/63 [==============================] - 62s 990ms/step - loss: 0.7770 - accuracy: 0.5316
Epoch 6/20
63/63 [==============================] - 60s 945ms/step - loss: 0.7408 - accuracy: 0.4947
Epoch 7/20
63/63 [==============================] - 67s 1s/step - loss: 0.7345 - accuracy: 0.5434
Epoch 8/20
63/63 [==============================] - 62s 984ms/step - loss: 0.7214 - accuracy: 0.5605
Epoch 9/20
63/63 [==============================] - 60s 950ms/step - loss: 0.7770 - accuracy: 0.4784
Epoch 10/20
63/63 [==============================] - 60s 956ms/step - loss: 0.7171 - accuracy: 0.5203
Epoch 11/20
63/63 [==============================] - 63s 994ms/step - loss: 0.7045 - accuracy: 0.4921
Epoch 12/20
63/63 [==============================] - 63s 1s/step - loss: 0.6884 - accuracy: 0.5430
Epoch 13/20
63/63 [==============================] - 60s 958ms/step - loss: 0.7333 - accuracy: 0.5278
Epoch 14/20
63/63 [==============================] - 61s 966ms/step - loss: 0.7050 - accuracy: 0.5106
Epoch 15/20
63/63 [==============================] - 59s 943ms/step - loss: 0.6958 - accuracy: 0.5622
Epoch 16/20
63/63 [==============================] - 60s 954ms/step - loss: 0.7398 - accuracy: 0.5172
Epoch 17/20
63/63 [==============================] - 69s 1s/step - loss: 0.7104 - accuracy: 0.5023
Epoch 18/20
63/63 [==============================] - 74s 1s/step - loss: 0.7411 - accuracy: 0.4747
Epoch 19/20
63/63 [==============================] - 67s 1s/step - loss: 0.7056 - accuracy: 0.4706
Epoch 20/20
63/63 [==============================] - 81s 1s/step - loss: 0.7901 - accuracy: 0.4898

Comparing ResNet-18 with ResNet-50, we can see that ResNet-50 does have a stronger fitting capacity.

Comparing the ResNet-18 and ResNet-50 variants without shortcut connections, ResNet-50 actually fits worse than ResNet-18. This matches the motivation of ResNet: without residual connections, a deeper network can end up with a higher training error because of gradient problems.

The training speed of the different models is also worth mentioning. After the training set was cut to a third of its size, ResNet-50 trained about as fast as ResNet-18. ResNet-50 looks like it has many more layers than ResNet-18 and uses convolutions with large channel counts in the middle of the network, yet its total parameter count does not grow by much, thanks to the bottleneck structure that cuts down the computation.
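
To make this concrete, here is a rough back-of-the-envelope sketch of my own (counting only convolution weights, ignoring biases and BatchNorm parameters) comparing the two block types at the deepest stage:

# Rough convolution-weight counts for the two residual block types
def basic_block_weights(c):
    # two 3x3 convolutions with c channels in and out
    return 2 * (3 * 3 * c * c)

def bottleneck_weights(c_mid, c_out):
    # 1x1 (c_out -> c_mid), 3x3 (c_mid -> c_mid), 1x1 (c_mid -> c_out)
    return c_out * c_mid + 3 * 3 * c_mid * c_mid + c_mid * c_out

print(basic_block_weights(512))       # 4718592
print(bottleneck_weights(512, 2048))  # 4456448

Even though the bottleneck block outputs four times as many channels (2048 vs. 512), it uses slightly fewer weights than the basic block.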

Summary

In this post, I presented TensorFlow implementations of ResNet-18 and ResNet-50. The code includes both implementations of the two classic kinds of residual blocks in ResNet, faithfully reproducing the building blocks of the models in the original paper. The experiments then verify the effectiveness of ResNet's residual connections.

Later, I will also write a PyTorch implementation of ResNet, together with a detailed walkthrough of the paper.