Transformer中的“注意力”最早来自于NLP里的注意力模型。通过动手实现一遍注意力模型,我们能够更深刻地理解注意力的原理,以便于学习Transformer等后续那些基于注意力的模型。在这篇文章中,我将分享如何用PyTorch的基本API实现注意力模型,完成一个简单的机器翻译项目——把各种格式的日期“翻译”成统一格式的日期。
有关机器翻译、注意力模型相关知识请参考我之前的文章。如序列模型与注意力机制 。
项目网址:https://github.com/SingleZombie/DL-Demos/tree/master/dldemos/attention
知识背景 注意力模型发源自机器翻译任务。最早,基于RNN的机器翻译模型都采用如下的架构:
前半部分的RNN只有输入,后半部分的RNN只有输出。两个部分通过一个简单的隐状态来传递信息。把隐状态看成输入信息的一种编码的话,前半部分可以叫做“编码器”,后半部分可以叫做“解码器”。这种架构因而被称为“编码器-解码器”架构。
这种架构在翻译短句子时确实有效,但面对长文章时就捉襟见肘了。使用“编码器-解码器”架构时,无论输入有多长,输入都会被压缩成一个简短的编码。也就是说,模型要一次性阅读完所有输入,再一次性输出所有翻译。这显然不是一种好的方法。联想一下,我们人类在翻译时,一般会读一句话,翻译一句话,读一句话,翻译一句话。基于这种思想,有人提出了注意力模型。注意力模型能够有效地翻译长文章。
在注意力模型中,编码器和解码器以另一种方式连接在一起。在完成编码后,解码器会以不同的权重去各个编码输出中取出相关信息,也就是以不同的“注意力”去关注输入信息。
具体来说,注意力模型的结构如下。
对于每一轮的输出$\hat{y}^{< t >}$,它的解码RNN的输入由上一轮输出$\hat{y}^{< t - 1>}$和注意力上下文$c^{< t >}$拼接而成。注意力上下文$c^{< t >}$,就是所有输入的编码RNN的隐变量$a^{< t >}$的一个加权平均数。这里加权平均数的权重$\alpha$就是该输出对每一个输入的注意力。每一个$\alpha$由编码RNN本轮状态$a^{< t’ >}$和解码RNN上一轮状态$s^{< t - 1 >}$决定。这两个输入会被送入一个简单的全连接网络,输出权重$e$(一个实数)。所有输入元素的$e$经过一个softmax输出$\alpha$。
日期翻译任务及其数据集 为了简化项目的实现,我们来完成一个简单的日期翻译任务。在这个任务中,输入是各式各样的日期,输出是某一个标准格式的日期。比如:
input
output
Nov 23, 1999
1999-11-23
3 April 2005
2005-04-03
14/01/1989
1989-01-14
Thursday, February 7, 1985
1985-02-07
我们可以自己动手用Python生成数据集。在生成数据集时,我们要用到随机生成日期的faker
库和格式化日期的babel
库。
运行下面这段代码,我们可以生成不同格式的日期。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 import randomfrom babel.dates import format_datefrom faker import Fakerfaker = Faker() format_list = [ 'short' , 'medium' , 'long' , 'full' , 'd MMM YYY' , 'd MMMM YYY' , 'dd/MM/YYY' , 'dd-MM-YYY' , 'EE d, MMM YYY' , 'EEEE d, MMMM YYY' ] if __name__ == '__main__' : for format in format_list: date_obj = faker.date_object() print (f'{format } :' , date_obj, format_date(date_obj, format =format , locale='en' ))
text 1 2 3 4 5 6 7 8 9 10 11 Possible output: short: 1986-02-25 2/25/86 medium: 1979-08-05 Aug 5, 1979 long: 1971-12-15 December 15, 1971 full: 2017-02-14 Tuesday, February 14, 2017 d MMM YYY: 1984-02-21 21 Feb 1984 d MMMM YYY: 2011-06-22 22 June 2011 dd/MM/YYY: 1991-08-02 02/08/1991 dd-MM-YYY: 1987-06-12 12-06-1987 EE d, MMM YYY: 1986-11-02 Sun 2, Nov 1986 EEEE d, MMMM YYY: 1996-01-26 Friday 26, January 1996
Faker()
是生成随机数据的代理类,用它的date_object()
方法可以随机生成一个日期字符串date_obj
。这个日期就是我们期望的标准格式。而通过使用format_date
函数,我们可以通过改变该函数的format
参数来得到格式不一样的日期字符串。各种格式的日期示例可以参考上面的输出。
利用这些工具函数,我们可以编写下面这些生成、读取数据集的函数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 def generate_date (): format = random.choice(format_list) date_obj = faker.date_object() formated_date = format_date(date_obj, format =format , locale='en' ) return formated_date, date_obj def generate_date_data (count, filename ): with open (filename, 'w' ) as fp: for _ in range (count): formated_date, date_obj = generate_date() fp.write(f'{formated_date} \t{date_obj} \n' ) def load_date_data (filename ): with open (filename, 'r' ) as fp: lines = fp.readlines() return [line.strip('\n' ).split('\t' ) for line in lines] generate_date_data(50000 , 'dldemos/attention/train.txt' ) generate_date_data(10000 , 'dldemos/attention/test.txt' )
注意力模型 在这个项目中,最难的部分是注意力模型的实现,即如何把上一节那个结构图用PyTorch描述出来。所有模型实现的代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 import torchimport torch.nn as nnfrom torch.nn.utils.rnn import pad_sequencefrom torch.utils.data import DataLoader, Datasetfrom dldemos.attention.dataset import generate_date, load_date_dataEMBEDDING_LENGTH = 128 OUTPUT_LENGTH = 10 class AttentionModel (nn.Module ): def __init__ (self, embeding_dim=32 , encoder_dim=32 , decoder_dim=32 , dropout_rate=0.5 ): super ().__init__() self.drop = nn.Dropout(dropout_rate) self.embedding = nn.Embedding(EMBEDDING_LENGTH, embeding_dim) self.attention_linear = nn.Linear(2 * encoder_dim + decoder_dim, 1 ) self.softmax = nn.Softmax(-1 ) self.encoder = nn.LSTM(embeding_dim, encoder_dim, 1 , batch_first=True , bidirectional=True ) self.decoder = nn.LSTM(EMBEDDING_LENGTH + 2 * encoder_dim, decoder_dim, 1 , batch_first=True ) self.output_linear = nn.Linear(decoder_dim, EMBEDDING_LENGTH) self.decoder_dim = decoder_dim def forward (self, x: torch.Tensor, n_output: int = OUTPUT_LENGTH ): batch, n_squence = x.shape[0 :2 ] x = self.drop(self.embedding(x)) a, _ = self.encoder(x) prev_s = x.new_zeros(batch, 1 , self.decoder_dim) prev_y = x.new_zeros(batch, 1 , EMBEDDING_LENGTH) y = x.new_empty(batch, n_output, EMBEDDING_LENGTH) tmp_states = None for i_output in range (n_output): repeat_s = prev_s.repeat(1 , n_squence, 1 ) attention_input = torch.cat((repeat_s, a), 2 ).reshape(batch * n_squence, -1 ) alpha = self.softmax(self.attention_linear(attention_input)) c = torch.sum (a * alpha.reshape(batch, n_squence, 1 ), 1 ) c = c.unsqueeze(1 ) decoder_input = torch.cat((prev_y, c), 2 ) if tmp_states is None : prev_s, tmp_states = self.decoder(decoder_input) else : prev_s, tmp_states = self.decoder(decoder_input, tmp_states) prev_y = self.output_linear(prev_s) y[:, i_output] = prev_y.squeeze(1 ) return y
让我们把这份实现一点一点过一遍。
在实现前,我们要准备一些常量。我们首先要决定“词汇表”的大小。在日期翻译任务中,输入和输出应当看成是字符序列。字符最多有128个,因此我们可以令“词汇表”大小为128。
在我们这个任务中,输出序列的长度是固定的。对于yyyy-mm-dd
这个日期字符串,其长度为10。我们要把这个常量也准备好。
接下来是模型的实现。先看__init__
里的结构定义。一开始,按照RNN模型的惯例,我们要让输入过Dropout和嵌入层。对于单词序列,使用预训练的单词嵌入会好一点。然而,我们这个项目用的是字符序列,直接定义一个可学习的嵌入层即可。
1 2 self.drop = nn.Dropout(dropout_rate) self.embedding = nn.Embedding(EMBEDDING_LENGTH, embeding_dim)
接下来是编码器和解码器。在注意力模型中,编码器和解码器是两个不同的RNN。为了充分利用输入信息,可以把双向RNN当作编码器。而由于机器翻译是一个生成答案的任务,每轮生成元素时需要用到上一轮生成出来的元素,解码器必须是一个单向RNN。在本项目中,我使用的RNN是LSTM。模块定义代码如下:
1 2 3 4 5 6 7 8 9 self.encoder = nn.LSTM(embeding_dim, encoder_dim, 1 , batch_first=True , bidirectional=True ) self.decoder = nn.LSTM(EMBEDDING_LENGTH + 2 * encoder_dim, decoder_dim, 1 , batch_first=True )
这里要注意一下这两个模块的输入通道数。encoder
的输入来自嵌入层,因此是embeding_dim
,这个很好理解。decoder
的输入通道则需要计算一番了。decoder
的输入由模型上一轮的输出和注意力输出拼接而成。模型每轮会输出一个字符,字符的通道数是“词汇表”大小,即EMBEDDING_LENGTH
。注意力的输出是encoder
的隐变量的加权和,因此其通道数和encoder
的隐变量一致。encoder
是双向RNN,其隐变量的通道数是2 * encoder_dim
。最终,decoder
的输入通道数应是EMBEDDING_LENGTH + 2 * encoder_dim
。
在注意力模块中,解码RNN对各编码RNN的注意力由一个线性层计算而得。该线性层的输入由解码RNN和编码RNN的隐变量拼接而成,因此其通道数为2 * encoder_dim + decoder_dim
;该线性层的输出是注意力权重——一个实数。
1 self.attention_linear = nn.Linear(2 * encoder_dim + decoder_dim, 1 )
解码结束后,还需要经过一个线性层才能输出结果。
1 self.output_linear = nn.Linear(decoder_dim, EMBEDDING_LENGTH)
看完了__init__
,来看看forward
里各模块是怎么连接起来的。
机器翻译其实是一个生成序列的任务。一般情况下,生成序列的长度是不确定的,需要用一些额外的技巧来选择最佳的输出序列。为了简化实现,在这个项目中,我们生成一个固定长度的输出序列。该长度应该在forward
的参数里指定。因此,forward
的参数如下:1 def forward (self, x: torch.Tensor, n_output: int = OUTPUT_LENGTH ):
一开始,先获取一些形状信息。
1 2 batch, n_squence = x.shape[0 :2 ]
输入通过嵌入层和dropout层。
1 2 x = self.drop(self.embedding(x))
再通过编码器,得到编码隐状态a
。
接下来,要用for循环输出每一轮的结果了。在此之前,我们要准备一些中间变量:用于计算注意力的解码器上一轮状态prev_s
,用于解码器输入的上一轮输出prev_y
,输出张量y
。另外,由于我们要在循环中手动调用decoder
完成每一轮的计算,还需要保存decoder
的所有中间变量tmp_states
。
1 2 3 4 5 6 7 prev_s = x.new_zeros(batch, 1 , self.decoder_dim) prev_y = x.new_zeros(batch, 1 , EMBEDDING_LENGTH) y = x.new_empty(batch, n_output, EMBEDDING_LENGTH) tmp_states = None
在每一轮输出中,我们首先要获得当前的解码器对于每一个输入的注意力alpha
。每一个alpha
由解码器上一轮状态prev_s
和编码器本轮状态决定(一个全连接层+softmax)。为了充分利用并行计算,我们可以把所有alpha的计算打包成batch,一步做完。
注意,这里的全连接层+softmax和普通的全连接网络不太一样。这里全连接层的输出通道数是1,会对n组输入做n次计算,得到n个结果,再对n个结果做softmax。我们之所以能一次得到n个结果,是巧妙地把n放到了batch那一维。
1 2 3 4 5 6 7 8 9 10 11 for i_output in range (n_output): repeat_s = prev_s.repeat(1 , n_squence, 1 ) attention_input = torch.cat((repeat_s, a), 2 ).reshape(batch * n_squence, -1 ) x = self.attention_linear(attention_input) x = x.reshape(batch, n_squence) alpha = self.softmax(x)
求出了注意力alpha
后,就可以用它来算出注意力上下文c
了。
1 c = torch.sum (a * alpha.reshape(batch, n_squence, 1 ), 1 )
之后,我们把c
和上一轮输出prev_y
拼一下,作为解码器的输出。1 2 c = c.unsqueeze(1 ) decoder_input = torch.cat((prev_y, c), 2 )
再调用解码器即可。这里我利用PyTorch的机制偷了个懒。理论上解码器第一轮的状态应该是全零张量,我们应该初始化两个全零张量作为LSTM的初始状态。但是,在PyTorch里,如果调用RNN时不传入状态,就默认会使用全零状态。因此,在第一轮调用时,我们可以不去传状态参数。
1 2 3 4 if tmp_states is None : prev_s, tmp_states = self.decoder(decoder_input) else : prev_s, tmp_states = self.decoder(decoder_input, tmp_states)
最后,用线性层算出这轮的输出,维护输出变量y
。循环结束后,返回y
。
1 2 3 prev_y = self.output_linear(prev_s) y[:, i_output] = prev_y.squeeze(1 ) return y
训练、测试、推理 写完了最核心的注意力模型,剩下的代码就比较简单了。
首先,我们要准备一个Dataset
类。这个类可以读取输入、输出字符串,并把它们转换成整形数组。字符和整形数字间的映射非常暴力,一个字符的序号就是该字符的ASCII码。这样写比较简洁,但由于很多字符是用不到的,会浪费一些计算性能。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 def stoi (str ): return torch.LongTensor([ord (char) for char in str ]) def itos (arr ): return '' .join([chr (x) for x in arr]) class DateDataset (Dataset ): def __init__ (self, lines ): self.lines = lines def __len__ (self ): return len (self.lines) def __getitem__ (self, index ): line = self.lines[index] return stoi(line[0 ]), stoi(line[1 ])
准备好DataSet后,就可以生成DataLoader了。在序列任务中,各个样本的序列长度可能是不一致的。我们可以用PyTorch的pad_sequence
对长度不足的样本进行0填充,使得一个batch里的所有样本都有着同样的序列长度。
1 2 3 4 5 6 7 8 9 10 11 def get_dataloader (filename ): def collate_fn (batch ): x, y = zip (*batch) x_pad = pad_sequence(x, batch_first=True ) y_pad = pad_sequence(y, batch_first=True ) return x_pad, y_pad lines = load_date_data(filename) dataset = DateDataset(lines) return DataLoader(dataset, 32 , collate_fn=collate_fn)
这里要稍微注意一下,pad_sequence
默认会做0填充,0填充在我们的项目里是合理的。在我们定义的“词汇表”里,0对应的是ASCII里的0号字符,这个字符不会和其他字符起冲突。
做好一切准备工作后,可以开始训练模型了。训练模型的代码非常常规,定义好Adam优化器、交叉熵误差,跑完模型后reshape一下算出loss再反向传播即可。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 def main (): device = 'cuda:0' train_dataloader = get_dataloader('dldemos/attention/train.txt' ) test_dataloader = get_dataloader('dldemos/attention/test.txt' ) model = AttentionModel().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001 ) citerion = torch.nn.CrossEntropyLoss() for epoch in range (20 ): loss_sum = 0 dataset_len = len (train_dataloader.dataset) for x, y in train_dataloader: x = x.to(device) y = y.to(device) hat_y = model(x) n, Tx, _ = hat_y.shape hat_y = torch.reshape(hat_y, (n * Tx, -1 )) label_y = torch.reshape(y, (n * Tx, )) loss = citerion(hat_y, label_y) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5 ) optimizer.step() loss_sum += loss * n print (f'Epoch {epoch} . loss: {loss_sum / dataset_len} ' ) torch.save(model.state_dict(), 'dldemos/attention/model.pth' )
训练完模型后,我们可以测试一下模型在测试集上的正确率。在日期翻译任务中,我们可以把“正确”定义为输出和真值一模一样。比如一条日期的真值是”2000-01-01”,模型的输出必须也是”2000-01-01”才能说这个输出是正确的。编写并行化计算正确率的代码稍有难度。
模型的输出hat_y
表示各个字符的出现概率。我们先用prediction = torch.argmax(hat_y, 2)
把序列里每个概率最大的字符作为模型预测的字符。现在,我们要用并行化编程判断每对序列(整形标签数组)predition[i]
和y[i]
是否相等(注意,predition
和y
是带了batch那个维度的)。这里,我们可以让predition[i]
和y[i]
做减法再求和。仅当这个和为0时,我们才能说predition[i]
和y[i]
完全相等。通过这样一种曲折的实现方法,我们可以并行地算出正确率。
也许有更方便的API可以完成这个逻辑判断,但去网上搜索这么复杂的一个需求太麻烦了,我偷了个懒。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 model.load_state_dict(torch.load('dldemos/attention/model.pth' )) accuracy = 0 dataset_len = len (test_dataloader.dataset) for x, y in test_dataloader: x = x.to(device) y = y.to(device) hat_y = model(x) prediction = torch.argmax(hat_y, 2 ) score = torch.where(torch.sum (prediction - y, -1 ) == 0 , 1 , 0 ) accuracy += torch.sum (score) print (f'Accuracy: {accuracy / dataset_len} ' )
最后,我们也可以临时生成几个测试用例,输出模型的预测结果。
1 2 3 4 5 6 7 8 9 for _ in range (5 ): x, y = generate_date() origin_x = x x = stoi(x).unsqueeze(0 ).to(device) hat_y = model(x) hat_y = hat_y.squeeze(0 ).argmax(1 ) hat_y = itos(hat_y) print (f'input: {origin_x} , prediction: {hat_y} , gt: {y} ' )
训练20-30个epoch后,模型差不多就收敛了。我训练的模型在测试集上的正确率约有98%。下面是随机测试用例的推理结果,可以看出模型的判断确实很准确。
text 1 2 3 4 5 input: 4 November 1988, prediction: 1988-11-04, gt: 1988-11-04 input: Friday 26, March 2021, prediction: 2021-03-26, gt: 2021-03-26 input: Saturday 2, December 1989, prediction: 1989-12-02, gt: 1989-12-02 input: 15/10/1971, prediction: 1971-10-15, gt: 1971-10-15 input: Mon 9, Oct 1989, prediction: 1989-10-09, gt: 1989-10-09
总结 在这篇文章中,我展示了一个用PyTorch编写的注意力模型,它用于完成日期翻译任务。在这个项目中,最重要的是注意力模型的编写。如今,注意力模型已经不是功能最强大的模型架构了。不过,通过动手实现这个模型,我们可以对注意力机制有着更深刻的认识,有助于理解那些更先进的模型。