0%

PyTorch 注意力模型实现详解(以简单的机器翻译为例)

Transformer中的“注意力”最早来自于NLP里的注意力模型。通过动手实现一遍注意力模型,我们能够更深刻地理解注意力的原理,以便于学习Transformer等后续那些基于注意力的模型。在这篇文章中,我将分享如何用PyTorch的基本API实现注意力模型,完成一个简单的机器翻译项目——把各种格式的日期“翻译”成统一格式的日期。

有关机器翻译、注意力模型相关知识请参考我之前的文章。如序列模型与注意力机制

项目网址:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/attention

知识背景

注意力模型发源自机器翻译任务。最早,基于RNN的机器翻译模型都采用如下的架构:

前半部分的RNN只有输入,后半部分的RNN只有输出。两个部分通过一个简单的隐状态来传递信息。把隐状态看成输入信息的一种编码的话,前半部分可以叫做“编码器”,后半部分可以叫做“解码器”。这种架构因而被称为“编码器-解码器”架构。

这种架构在翻译短句子时确实有效,但面对长文章时就捉襟见肘了。使用“编码器-解码器”架构时,无论输入有多长,输入都会被压缩成一个简短的编码。也就是说,模型要一次性阅读完所有输入,再一次性输出所有翻译。这显然不是一种好的方法。联想一下,我们人类在翻译时,一般会读一句话,翻译一句话,读一句话,翻译一句话。基于这种思想,有人提出了注意力模型。注意力模型能够有效地翻译长文章。

在注意力模型中,编码器和解码器以另一种方式连接在一起。在完成编码后,解码器会以不同的权重去各个编码输出中取出相关信息,也就是以不同的“注意力”去关注输入信息。

具体来说,注意力模型的结构如下。

对于每一轮的输出$\hat{y}^{< t >}$,它的解码RNN的输入由上一轮输出$\hat{y}^{< t - 1>}$和注意力上下文$c^{< t >}$拼接而成。注意力上下文$c^{< t >}$,就是所有输入的编码RNN的隐变量$a^{< t >}$的一个加权平均数。这里加权平均数的权重$\alpha$就是该输出对每一个输入的注意力。每一个$\alpha$由编码RNN本轮状态$a^{< t’ >}$和解码RNN上一轮状态$s^{< t - 1 >}$决定。这两个输入会被送入一个简单的全连接网络,输出权重$e$(一个实数)。所有输入元素的$e$经过一个softmax输出$\alpha$。

日期翻译任务及其数据集

为了简化项目的实现,我们来完成一个简单的日期翻译任务。在这个任务中,输入是各式各样的日期,输出是某一个标准格式的日期。比如:

input output
Nov 23, 1999 1999-11-23
3 April 2005 2005-04-03
14/01/1989 1989-01-14
Thursday, February 7, 1985 1985-02-07

我们可以自己动手用Python生成数据集。在生成数据集时,我们要用到随机生成日期的faker库和格式化日期的babel库。

1
pip install faker babel

运行下面这段代码,我们可以生成不同格式的日期。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import random

from babel.dates import format_date
from faker import Faker

faker = Faker()
format_list = [
'short', 'medium', 'long', 'full', 'd MMM YYY', 'd MMMM YYY', 'dd/MM/YYY',
'dd-MM-YYY', 'EE d, MMM YYY', 'EEEE d, MMMM YYY'
]

if __name__ == '__main__':
for format in format_list:
date_obj = faker.date_object()
print(f'{format}:', date_obj,
format_date(date_obj, format=format, locale='en'))
text
1
2
3
4
5
6
7
8
9
10
11
Possible output:
short: 1986-02-25 2/25/86
medium: 1979-08-05 Aug 5, 1979
long: 1971-12-15 December 15, 1971
full: 2017-02-14 Tuesday, February 14, 2017
d MMM YYY: 1984-02-21 21 Feb 1984
d MMMM YYY: 2011-06-22 22 June 2011
dd/MM/YYY: 1991-08-02 02/08/1991
dd-MM-YYY: 1987-06-12 12-06-1987
EE d, MMM YYY: 1986-11-02 Sun 2, Nov 1986
EEEE d, MMMM YYY: 1996-01-26 Friday 26, January 1996

Faker()是生成随机数据的代理类,用它的date_object()方法可以随机生成一个日期字符串date_obj。这个日期就是我们期望的标准格式。而通过使用format_date函数,我们可以通过改变该函数的format参数来得到格式不一样的日期字符串。各种格式的日期示例可以参考上面的输出。

利用这些工具函数,我们可以编写下面这些生成、读取数据集的函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def generate_date():
format = random.choice(format_list)
date_obj = faker.date_object()
formated_date = format_date(date_obj, format=format, locale='en')
return formated_date, date_obj


def generate_date_data(count, filename):
with open(filename, 'w') as fp:
for _ in range(count):
formated_date, date_obj = generate_date()
fp.write(f'{formated_date}\t{date_obj}\n')


def load_date_data(filename):
with open(filename, 'r') as fp:
lines = fp.readlines()
return [line.strip('\n').split('\t') for line in lines]


generate_date_data(50000, 'dldemos/attention/train.txt')
generate_date_data(10000, 'dldemos/attention/test.txt')

注意力模型

在这个项目中,最难的部分是注意力模型的实现,即如何把上一节那个结构图用PyTorch描述出来。所有模型实现的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

from dldemos.attention.dataset import generate_date, load_date_data


EMBEDDING_LENGTH = 128
OUTPUT_LENGTH = 10

class AttentionModel(nn.Module):
def __init__(self,
embeding_dim=32,
encoder_dim=32,
decoder_dim=32,
dropout_rate=0.5):
super().__init__()
self.drop = nn.Dropout(dropout_rate)
self.embedding = nn.Embedding(EMBEDDING_LENGTH, embeding_dim)
self.attention_linear = nn.Linear(2 * encoder_dim + decoder_dim, 1)
self.softmax = nn.Softmax(-1)
self.encoder = nn.LSTM(embeding_dim,
encoder_dim,
1,
batch_first=True,
bidirectional=True)
self.decoder = nn.LSTM(EMBEDDING_LENGTH + 2 * encoder_dim,
decoder_dim,
1,
batch_first=True)
self.output_linear = nn.Linear(decoder_dim, EMBEDDING_LENGTH)
self.decoder_dim = decoder_dim

def forward(self, x: torch.Tensor, n_output: int = OUTPUT_LENGTH):
# x: [batch, n_sequence, EMBEDDING_LENGTH]
batch, n_squence = x.shape[0:2]

# x: [batch, n_sequence, embeding_dim]
x = self.drop(self.embedding(x))

# a: [batch, n_sequence, hidden]
a, _ = self.encoder(x)

# prev_s: [batch, n_squence=1, hidden]
# prev_y: [batch, n_squence=1, EMBEDDING_LENGTH]
# y: [batch, n_output, EMBEDDING_LENGTH]
prev_s = x.new_zeros(batch, 1, self.decoder_dim)
prev_y = x.new_zeros(batch, 1, EMBEDDING_LENGTH)
y = x.new_empty(batch, n_output, EMBEDDING_LENGTH)
tmp_states = None
for i_output in range(n_output):
# repeat_s: [batch, n_squence, hidden]
repeat_s = prev_s.repeat(1, n_squence, 1)
# attention_input: [batch * n_sequence, hidden_s + hidden_a]
attention_input = torch.cat((repeat_s, a),
2).reshape(batch * n_squence, -1)
alpha = self.softmax(self.attention_linear(attention_input))
c = torch.sum(a * alpha.reshape(batch, n_squence, 1), 1)
c = c.unsqueeze(1)
decoder_input = torch.cat((prev_y, c), 2)

if tmp_states is None:
prev_s, tmp_states = self.decoder(decoder_input)
else:
prev_s, tmp_states = self.decoder(decoder_input, tmp_states)

prev_y = self.output_linear(prev_s)
y[:, i_output] = prev_y.squeeze(1)
return y

让我们把这份实现一点一点过一遍。

在实现前,我们要准备一些常量。我们首先要决定“词汇表”的大小。在日期翻译任务中,输入和输出应当看成是字符序列。字符最多有128个,因此我们可以令“词汇表”大小为128。

1
EMBEDDING_LENGTH = 128

在我们这个任务中,输出序列的长度是固定的。对于yyyy-mm-dd这个日期字符串,其长度为10。我们要把这个常量也准备好。

1
OUTPUT_LENGTH = 10

接下来是模型的实现。先看__init__里的结构定义。一开始,按照RNN模型的惯例,我们要让输入过Dropout和嵌入层。对于单词序列,使用预训练的单词嵌入会好一点。然而,我们这个项目用的是字符序列,直接定义一个可学习的嵌入层即可。

1
2
self.drop = nn.Dropout(dropout_rate)
self.embedding = nn.Embedding(EMBEDDING_LENGTH, embeding_dim)

接下来是编码器和解码器。在注意力模型中,编码器和解码器是两个不同的RNN。为了充分利用输入信息,可以把双向RNN当作编码器。而由于机器翻译是一个生成答案的任务,每轮生成元素时需要用到上一轮生成出来的元素,解码器必须是一个单向RNN。在本项目中,我使用的RNN是LSTM。模块定义代码如下:

1
2
3
4
5
6
7
8
9
self.encoder = nn.LSTM(embeding_dim,
encoder_dim,
1,
batch_first=True,
bidirectional=True)
self.decoder = nn.LSTM(EMBEDDING_LENGTH + 2 * encoder_dim,
decoder_dim,
1,
batch_first=True)

这里要注意一下这两个模块的输入通道数。encoder的输入来自嵌入层,因此是embeding_dim,这个很好理解。decoder的输入通道则需要计算一番了。decoder的输入由模型上一轮的输出和注意力输出拼接而成。模型每轮会输出一个字符,字符的通道数是“词汇表”大小,即EMBEDDING_LENGTH。注意力的输出是encoder的隐变量的加权和,因此其通道数和encoder的隐变量一致。encoder是双向RNN,其隐变量的通道数是2 * encoder_dim。最终,decoder的输入通道数应是EMBEDDING_LENGTH + 2 * encoder_dim

在注意力模块中,解码RNN对各编码RNN的注意力由一个线性层计算而得。该线性层的输入由解码RNN和编码RNN的隐变量拼接而成,因此其通道数为2 * encoder_dim + decoder_dim;该线性层的输出是注意力权重——一个实数。

1
self.attention_linear = nn.Linear(2 * encoder_dim + decoder_dim, 1)

解码结束后,还需要经过一个线性层才能输出结果。

1
self.output_linear = nn.Linear(decoder_dim, EMBEDDING_LENGTH)

看完了__init__,来看看forward里各模块是怎么连接起来的。

机器翻译其实是一个生成序列的任务。一般情况下,生成序列的长度是不确定的,需要用一些额外的技巧来选择最佳的输出序列。为了简化实现,在这个项目中,我们生成一个固定长度的输出序列。该长度应该在forward的参数里指定。因此,forward的参数如下:

1
def forward(self, x: torch.Tensor, n_output: int = OUTPUT_LENGTH):

一开始,先获取一些形状信息。

1
2
# x: [batch, n_sequence, EMBEDDING_LENGTH]
batch, n_squence = x.shape[0:2]

输入通过嵌入层和dropout层。

1
2
# x: [batch, n_sequence, embeding_dim]
x = self.drop(self.embedding(x))

再通过编码器,得到编码隐状态a

1
2
# a: [batch, n_sequence, hidden]
a, _ = self.encoder(x)

接下来,要用for循环输出每一轮的结果了。在此之前,我们要准备一些中间变量:用于计算注意力的解码器上一轮状态prev_s,用于解码器输入的上一轮输出prev_y,输出张量y。另外,由于我们要在循环中手动调用decoder完成每一轮的计算,还需要保存decoder的所有中间变量tmp_states

1
2
3
4
5
6
7
# prev_s: [batch, n_squence=1, hidden]
# prev_y: [batch, n_squence=1, EMBEDDING_LENGTH]
# y: [batch, n_output, EMBEDDING_LENGTH]
prev_s = x.new_zeros(batch, 1, self.decoder_dim)
prev_y = x.new_zeros(batch, 1, EMBEDDING_LENGTH)
y = x.new_empty(batch, n_output, EMBEDDING_LENGTH)
tmp_states = None

在每一轮输出中,我们首先要获得当前的解码器对于每一个输入的注意力alpha。每一个alpha由解码器上一轮状态prev_s和编码器本轮状态决定(一个全连接层+softmax)。为了充分利用并行计算,我们可以把所有alpha的计算打包成batch,一步做完。

注意,这里的全连接层+softmax和普通的全连接网络不太一样。这里全连接层的输出通道数是1,会对n组输入做n次计算,得到n个结果,再对n个结果做softmax。我们之所以能一次得到n个结果,是巧妙地把n放到了batch那一维。

1
2
3
4
5
6
7
8
9
10
11
for i_output in range(n_output):
# repeat_s: [batch, n_squence, hidden]
repeat_s = prev_s.repeat(1, n_squence, 1)
# attention_input: [batch * n_sequence, hidden_s + hidden_a]
attention_input = torch.cat((repeat_s, a),
2).reshape(batch * n_squence, -1)
# x: [batch * n_sequence, 1]
x = self.attention_linear(attention_input)
# x: [batch, n_sequence]
x = x.reshape(batch, n_squence)
alpha = self.softmax(x)

求出了注意力alpha后,就可以用它来算出注意力上下文c了。

1
c = torch.sum(a * alpha.reshape(batch, n_squence, 1), 1)

之后,我们把c和上一轮输出prev_y拼一下,作为解码器的输出。

1
2
c = c.unsqueeze(1)
decoder_input = torch.cat((prev_y, c), 2)

再调用解码器即可。这里我利用PyTorch的机制偷了个懒。理论上解码器第一轮的状态应该是全零张量,我们应该初始化两个全零张量作为LSTM的初始状态。但是,在PyTorch里,如果调用RNN时不传入状态,就默认会使用全零状态。因此,在第一轮调用时,我们可以不去传状态参数。

1
2
3
4
if tmp_states is None:
prev_s, tmp_states = self.decoder(decoder_input)
else:
prev_s, tmp_states = self.decoder(decoder_input, tmp_states)

最后,用线性层算出这轮的输出,维护输出变量y。循环结束后,返回y

1
2
3
    prev_y = self.output_linear(prev_s)
y[:, i_output] = prev_y.squeeze(1)
return y

训练、测试、推理

写完了最核心的注意力模型,剩下的代码就比较简单了。

首先,我们要准备一个Dataset类。这个类可以读取输入、输出字符串,并把它们转换成整形数组。字符和整形数字间的映射非常暴力,一个字符的序号就是该字符的ASCII码。这样写比较简洁,但由于很多字符是用不到的,会浪费一些计算性能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def stoi(str):
return torch.LongTensor([ord(char) for char in str])


def itos(arr):
return ''.join([chr(x) for x in arr])


class DateDataset(Dataset):
def __init__(self, lines):
self.lines = lines

def __len__(self):
return len(self.lines)

def __getitem__(self, index):
line = self.lines[index]

return stoi(line[0]), stoi(line[1])

准备好DataSet后,就可以生成DataLoader了。在序列任务中,各个样本的序列长度可能是不一致的。我们可以用PyTorch的pad_sequence对长度不足的样本进行0填充,使得一个batch里的所有样本都有着同样的序列长度。

1
2
3
4
5
6
7
8
9
10
11
def get_dataloader(filename):

def collate_fn(batch):
x, y = zip(*batch)
x_pad = pad_sequence(x, batch_first=True)
y_pad = pad_sequence(y, batch_first=True)
return x_pad, y_pad

lines = load_date_data(filename)
dataset = DateDataset(lines)
return DataLoader(dataset, 32, collate_fn=collate_fn)

这里要稍微注意一下,pad_sequence默认会做0填充,0填充在我们的项目里是合理的。在我们定义的“词汇表”里,0对应的是ASCII里的0号字符,这个字符不会和其他字符起冲突。

做好一切准备工作后,可以开始训练模型了。训练模型的代码非常常规,定义好Adam优化器、交叉熵误差,跑完模型后reshape一下算出loss再反向传播即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def main():
device = 'cuda:0'
train_dataloader = get_dataloader('dldemos/attention/train.txt')
test_dataloader = get_dataloader('dldemos/attention/test.txt')

model = AttentionModel().to(device)

# train

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
citerion = torch.nn.CrossEntropyLoss()
for epoch in range(20):

loss_sum = 0
dataset_len = len(train_dataloader.dataset)

for x, y in train_dataloader:
x = x.to(device)
y = y.to(device)
hat_y = model(x)
n, Tx, _ = hat_y.shape
hat_y = torch.reshape(hat_y, (n * Tx, -1))
label_y = torch.reshape(y, (n * Tx, ))
loss = citerion(hat_y, label_y)

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()

loss_sum += loss * n

print(f'Epoch {epoch}. loss: {loss_sum / dataset_len}')

torch.save(model.state_dict(), 'dldemos/attention/model.pth')

训练完模型后,我们可以测试一下模型在测试集上的正确率。在日期翻译任务中,我们可以把“正确”定义为输出和真值一模一样。比如一条日期的真值是”2000-01-01”,模型的输出必须也是”2000-01-01”才能说这个输出是正确的。编写并行化计算正确率的代码稍有难度。

模型的输出hat_y表示各个字符的出现概率。我们先用prediction = torch.argmax(hat_y, 2)把序列里每个概率最大的字符作为模型预测的字符。现在,我们要用并行化编程判断每对序列(整形标签数组)predition[i]y[i]是否相等(注意,preditiony是带了batch那个维度的)。这里,我们可以让predition[i]y[i]做减法再求和。仅当这个和为0时,我们才能说predition[i]y[i]完全相等。通过这样一种曲折的实现方法,我们可以并行地算出正确率。

也许有更方便的API可以完成这个逻辑判断,但去网上搜索这么复杂的一个需求太麻烦了,我偷了个懒。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# test
model.load_state_dict(torch.load('dldemos/attention/model.pth'))

accuracy = 0
dataset_len = len(test_dataloader.dataset)

for x, y in test_dataloader:
x = x.to(device)
y = y.to(device)
hat_y = model(x)
prediction = torch.argmax(hat_y, 2)
score = torch.where(torch.sum(prediction - y, -1) == 0, 1, 0)
accuracy += torch.sum(score)

print(f'Accuracy: {accuracy / dataset_len}')

最后,我们也可以临时生成几个测试用例,输出模型的预测结果。

1
2
3
4
5
6
7
8
9
# inference
for _ in range(5):
x, y = generate_date()
origin_x = x
x = stoi(x).unsqueeze(0).to(device)
hat_y = model(x)
hat_y = hat_y.squeeze(0).argmax(1)
hat_y = itos(hat_y)
print(f'input: {origin_x}, prediction: {hat_y}, gt: {y}')

训练20-30个epoch后,模型差不多就收敛了。我训练的模型在测试集上的正确率约有98%。下面是随机测试用例的推理结果,可以看出模型的判断确实很准确。

text
1
2
3
4
5
input: 4 November 1988, prediction: 1988-11-04, gt: 1988-11-04
input: Friday 26, March 2021, prediction: 2021-03-26, gt: 2021-03-26
input: Saturday 2, December 1989, prediction: 1989-12-02, gt: 1989-12-02
input: 15/10/1971, prediction: 1971-10-15, gt: 1971-10-15
input: Mon 9, Oct 1989, prediction: 1989-10-09, gt: 1989-10-09

总结

在这篇文章中,我展示了一个用PyTorch编写的注意力模型,它用于完成日期翻译任务。在这个项目中,最重要的是注意力模型的编写。如今,注意力模型已经不是功能最强大的模型架构了。不过,通过动手实现这个模型,我们可以对注意力机制有着更深刻的认识,有助于理解那些更先进的模型。

欢迎关注我的其它发布渠道