圖解 Attention(完整版)!
本文約4000字,建議閱讀11分鐘
審稿人:Jepson,Datawhale成員,畢業(yè)于中國(guó)科學(xué)院,目前在騰訊從事推薦算法工作。
序列到序列(seq2seq)模型是一種深度學(xué)習(xí)模型,在很多任務(wù)上都取得了成功,如:機(jī)器翻譯、文本摘要、圖像描述生成。谷歌翻譯在 2016 年年末開(kāi)始使用這種模型。有2篇開(kāi)創(chuàng)性的論文:
Sutskever等2014年發(fā)布的:https://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf,
Cho等2014年發(fā)布的:http://emnlp2014.org/papers/pdf/EMNLP2014179.pdf
都對(duì)這些模型進(jìn)行了解釋。
然而,我發(fā)現(xiàn),想要充分理解模型并實(shí)現(xiàn)它,需要拆解一系列概念,而這些概念是層層遞進(jìn)的。我認(rèn)為,如果能夠把這些概念進(jìn)行可視化,會(huì)更加容易理解。這就是這篇文章的目標(biāo)。你需要先了解一些深度學(xué)習(xí)的知識(shí),才能讀完這篇文章。我希望這篇文章,可以對(duì)你閱讀上面提到的 2 篇論文有幫助。
一個(gè)序列到序列(seq2seq)模型,接收的輸入是一個(gè)輸入的(單詞、字母、圖像特征)序列,輸出是另外一個(gè)序列。一個(gè)訓(xùn)練好的模型如下圖所示:

在神經(jīng)機(jī)器翻譯中,一個(gè)序列是指一連串的單詞。類(lèi)似地,輸出也是一連串單詞。

進(jìn)一步理解細(xì)節(jié)
模型是由編碼器(Encoder)和解碼器(Decoder)組成的。其中,編碼器會(huì)處理輸入序列中的每個(gè)元素,把這些信息轉(zhuǎn)換為一個(gè)向量(稱(chēng)為上下文(context))。當(dāng)我們處理完整個(gè)輸入序列后,編碼器把上下文(context)發(fā)送給解碼器,解碼器開(kāi)始逐項(xiàng)生成輸出序列中的元素。

這種機(jī)制,同樣適用于機(jī)器翻譯。
在機(jī)器翻譯任務(wù)中,上下文(context)是一個(gè)向量(基本上是一個(gè)數(shù)字?jǐn)?shù)組)。編碼器和解碼器一般都是循環(huán)神經(jīng)網(wǎng)絡(luò),一定要看看 Luis Serrano寫(xiě)的一篇關(guān)于循環(huán)神經(jīng)網(wǎng)絡(luò)(https://www.youtube.com/watch?v=UNmqTiOnRfg)的精彩介紹
你可以在設(shè)置模型的時(shí)候設(shè)置上下文向量的長(zhǎng)度。這個(gè)長(zhǎng)度是基于編碼器 RNN 的隱藏層神經(jīng)元的數(shù)量。上圖展示了長(zhǎng)度為 4 的向量,但在實(shí)際應(yīng)用中,上下文向量的長(zhǎng)度可能是 256,512 或者 1024。
根據(jù)設(shè)計(jì),RNN 在每個(gè)時(shí)間步接受 2 個(gè)輸入:
-
輸入序列中的一個(gè)元素(在解碼器的例子中,輸入是指句子中的一個(gè)單詞) -
一個(gè) hidden state(隱藏層狀態(tài))
然而每個(gè)單詞都需要表示為一個(gè)向量。為了把一個(gè)詞轉(zhuǎn)換為一個(gè)向量,我們使用一類(lèi)稱(chēng)為 "word embedding" 的方法。這類(lèi)方法把單詞轉(zhuǎn)換到一個(gè)向量空間,這種表示能夠捕捉大量的單詞的語(yǔ)義信息(例如,king - man + woman = queen (http://p.migdal.pl/2017/01/06/king-man-woman-queen-why.html))。
現(xiàn)在,我們已經(jīng)介紹完了向量/張量的基礎(chǔ)知識(shí),讓我們回顧一下 RNN 的機(jī)制,并可視化這些 RNN 模型:

RNN 在第 2 個(gè)時(shí)間步,采用第 1 個(gè)時(shí)間步的 hidden state(隱藏層狀態(tài)) 和第 2 個(gè)時(shí)間步的輸入向量,來(lái)得到輸出。在下文,我們會(huì)使用類(lèi)似這種動(dòng)畫(huà),來(lái)描述神經(jīng)機(jī)器翻譯模型里的所有向量。
在下面的可視化圖形中,編碼器和解碼器在每個(gè)時(shí)間步處理輸入,并得到輸出。由于編碼器和解碼器都是 RNN,RNN 會(huì)根據(jù)當(dāng)前時(shí)間步的輸入,和前一個(gè)時(shí)間步的 hidden state(隱藏層狀態(tài)),更新當(dāng)前時(shí)間步的 hidden state(隱藏層狀態(tài))。
讓我們看下編碼器的 hidden state(隱藏層狀態(tài))。注意,最后一個(gè) hidden state(隱藏層狀態(tài))實(shí)際上是我們傳給解碼器的上下文(context)。
解碼器也持有 hidden state(隱藏層狀態(tài)),而且也需要把 hidden state(隱藏層狀態(tài))從一個(gè)時(shí)間步傳遞到下一個(gè)時(shí)間步。我們沒(méi)有在上圖中可視化解碼器的 hidden state,是因?yàn)檫@個(gè)過(guò)程和解碼器是類(lèi)似的,我們現(xiàn)在關(guān)注的是 RNN 的主要處理過(guò)程。現(xiàn)在讓我們用另一種方式來(lái)可視化序列到序列(seq2seq)模型。下面的動(dòng)畫(huà)會(huì)讓我們更加容易理解模型。這種方法稱(chēng)為展開(kāi)視圖。其中,我們不只是顯示一個(gè)解碼器,而是在時(shí)間上展開(kāi),每個(gè)時(shí)間步都顯示一個(gè)解碼器。通過(guò)這種方式,我們可以看到每個(gè)時(shí)間步的輸入和輸出。
Attention 講解
事實(shí)證明,上下文向量是這類(lèi)模型的瓶頸。這使得模型在處理長(zhǎng)文本時(shí)面臨非常大的挑戰(zhàn)。
在 Bahdanau等2014發(fā)布的(https://arxiv.org/abs/1409.0473) 和 Luong等2015年發(fā)布的(https://arxiv.org/abs/1508.04025) 兩篇論文中,提出了一種解決方法。這 2 篇論文提出并改進(jìn)了一種叫做注意力(Attention)的技術(shù),它極大地提高了機(jī)器翻譯的質(zhì)量。注意力使得模型可以根據(jù)需要,關(guān)注到輸入序列的相關(guān)部分。
讓我們繼續(xù)從高層次來(lái)理解注意力模型。一個(gè)注意力模型不同于經(jīng)典的序列到序列(seq2seq)模型,主要體現(xiàn)在 2 個(gè)方面:
首先,編碼器會(huì)把更多的數(shù)據(jù)傳遞給解碼器。編碼器把所有時(shí)間步的 hidden state(隱藏層狀態(tài))傳遞給解碼器,而不是只傳遞最后一個(gè) hidden state(隱藏層狀態(tài))。
第二,注意力模型的解碼器在產(chǎn)生輸出之前,做了一個(gè)額外的處理。為了把注意力集中在與該時(shí)間步相關(guān)的輸入部分。解碼器做了如下的處理:
-
查看所有接收到的編碼器的 hidden state(隱藏層狀態(tài))。其中,編碼器中每個(gè) hidden state(隱藏層狀態(tài))都對(duì)應(yīng)到輸入句子中一個(gè)單詞。 -
給每個(gè) hidden state(隱藏層狀態(tài))一個(gè)分?jǐn)?shù)(我們先忽略這個(gè)分?jǐn)?shù)的計(jì)算過(guò)程)。 -
將每個(gè) hidden state(隱藏層狀態(tài))乘以經(jīng)過(guò) softmax 的對(duì)應(yīng)的分?jǐn)?shù),從而,高分對(duì)應(yīng)的 hidden state(隱藏層狀態(tài))會(huì)被放大,而低分對(duì)應(yīng)的 hidden state(隱藏層狀態(tài))會(huì)被縮小。
這個(gè)加權(quán)平均的步驟是在解碼器的每個(gè)時(shí)間步做的。
現(xiàn)在,讓我們把所有內(nèi)容都融合到下面的圖中,來(lái)看看注意力模型的整個(gè)過(guò)程:
-
注意力模型的解碼器 RNN 的輸入包括:一個(gè)embedding 向量,和一個(gè)初始化好的解碼器 hidden state(隱藏層狀態(tài))。 -
RNN 處理上述的 2 個(gè)輸入,產(chǎn)生一個(gè)輸出和一個(gè)新的 hidden state(隱藏層狀態(tài) h4 向量),其中輸出會(huì)被忽略。 -
注意力的步驟:我們使用編碼器的 hidden state(隱藏層狀態(tài))和 h4 向量來(lái)計(jì)算這個(gè)時(shí)間步的上下文向量(C4)。 -
我們把 h4 和 C4 拼接起來(lái),得到一個(gè)向量。 -
我們把這個(gè)向量輸入一個(gè)前饋神經(jīng)網(wǎng)絡(luò)(這個(gè)網(wǎng)絡(luò)是和整個(gè)模型一起訓(xùn)練的)。 -
前饋神經(jīng)網(wǎng)絡(luò)的輸出的輸出表示這個(gè)時(shí)間步輸出的單詞。 -
在下一個(gè)時(shí)間步重復(fù)這個(gè)步驟。
下圖,我們使用另一種方式來(lái)可視化注意力,看看在每個(gè)解碼的時(shí)間步中關(guān)注輸入句子的哪些部分:
如果你覺(jué)得你準(zhǔn)備好了學(xué)習(xí)注意力機(jī)制的代碼實(shí)現(xiàn),一定要看看基于 TensorFlow 的 神經(jīng)機(jī)器翻譯 (seq2seq) 指南(https://github.com/tensorflow/nmt)
本文經(jīng)原作者 @JayAlammmar(https://twitter.com/JayAlammar) 授權(quán)翻譯,期望你的反饋。
后臺(tái)回復(fù)關(guān)鍵詞【張賢】可進(jìn)NLP交流群,和作者一起學(xué)習(xí)NLP。
本文翻譯自: https://jalammar.github.io/illustrated-bert/
