【深度學(xué)習(xí)】圖解自注意力機(jī)制(Self-Attention)
共 5117字,需瀏覽 11分鐘
·
2024-04-25 12:00
一、注意力機(jī)制和自注意力機(jī)制的區(qū)別
Attention機(jī)制與Self-Attention機(jī)制的區(qū)別
傳統(tǒng)的Attention機(jī)制發(fā)生在Target的元素和Source中的所有元素之間。
簡單講就是說Attention機(jī)制中的權(quán)重的計(jì)算需要Target來參與。即在Encoder-Decoder 模型中,Attention權(quán)值的計(jì)算不僅需要Encoder中的隱狀態(tài)而且還需要Decoder中的隱狀態(tài)。
Self-Attention:
不是輸入語句和輸出語句之間的Attention機(jī)制,而是輸入語句內(nèi)部元素之間或者輸出語句內(nèi)部元素之間發(fā)生的Attention機(jī)制。
例如在Transformer中在計(jì)算權(quán)重參數(shù)時(shí),將文字向量轉(zhuǎn)成對應(yīng)的KQV,只需要在Source處進(jìn)行對應(yīng)的矩陣操作,用不到Target中的信息。
二、引入自注意力機(jī)制的目的
神經(jīng)網(wǎng)絡(luò)接收的輸入是很多大小不一的向量,并且不同向量向量之間有一定的關(guān)系,但是實(shí)際訓(xùn)練的時(shí)候無法充分發(fā)揮這些輸入之間的關(guān)系而導(dǎo)致模型訓(xùn)練結(jié)果效果極差。比如機(jī)器翻譯問題(序列到序列的問題,機(jī)器自己決定多少個(gè)標(biāo)簽),詞性標(biāo)注問題(一個(gè)向量對應(yīng)一個(gè)標(biāo)簽),語義分析問題(多個(gè)向量對應(yīng)一個(gè)標(biāo)簽)等文字處理問題。
針對全連接神經(jīng)網(wǎng)絡(luò)對于多個(gè)相關(guān)的輸入無法建立起相關(guān)性的這個(gè)問題,通過自注意力機(jī)制來解決,自注意力機(jī)制實(shí)際上是想讓機(jī)器注意到整個(gè)輸入中不同部分之間的相關(guān)性。
三、Self-Attention詳解
針對輸入是一組向量,輸出也是一組向量,輸入長度為N(N可變化)的向量,輸出同樣為長度為N 的向量。
3.1 單個(gè)輸出
對于每一個(gè)輸入向量a,經(jīng)過藍(lán)色部分self-attention之后都輸出一個(gè)向量b,這個(gè)向量b是考慮了所有的輸入向量對a1產(chǎn)生的影響才得到的,這里有四個(gè)詞向量a對應(yīng)就會輸出四個(gè)向量b。
下面以b1的輸出為例
首先,如何計(jì)算sequence中各向量與a1的關(guān)聯(lián)程度,有下面兩種方法
Dot-product方法是將兩個(gè)向量乘上不同的矩陣w,得到q和k,做點(diǎn)積得到α,transformer中就用到了Dot-product。
上圖中綠色的部分就是輸入向量a1和a2,灰色的Wq和Wk為權(quán)重矩陣,需要學(xué)習(xí)來更新,用a1去和Wq相乘,得到一個(gè)向量q,然后使用a2和Wk相乘,得到一個(gè)數(shù)值k。最后使用q和k做點(diǎn)積,得到α。α也就是表示兩個(gè)向量之間的相關(guān)聯(lián)程度。
上圖右邊加性模型這種機(jī)制也是輸入向量與權(quán)重矩陣相乘,后相加,然后使用tanh投射到一個(gè)新的函數(shù)空間內(nèi),再與權(quán)重矩陣相乘,得到最后的結(jié)果。
可以計(jì)算每一個(gè)α(又稱為attention score),q稱為query,k稱為key
另外,也可以計(jì)算a1和自己的關(guān)聯(lián)性,再得到各向量與a1的相關(guān)程度之后,用softmax計(jì)算出一個(gè)attention distribution,這樣就把相關(guān)程度歸一化,通過數(shù)值就可以看出哪些向量是和a1最有關(guān)系。
下面需要根據(jù) α′ 抽取sequence里重要的資訊:
先求v,v就是鍵值value,v和q、k計(jì)算方式相同,也是用輸入a乘以權(quán)重矩陣W,得到v后,與對應(yīng)的α′ 相乘,每一個(gè)v乘與α'后求和,得到輸出b1。
如果 a1 和 a2 關(guān)聯(lián)性比較高, α1,2′ 就比較大,那么,得到的輸出 b1 就可能比較接近 v2 ,即attention score決定了該vector在結(jié)果中占的分量;
3.2 矩陣形式
寫成矩陣形式:
把4個(gè)輸入a拼成一個(gè)矩陣,這個(gè)矩陣有4個(gè)column,也就是a1到a4,
乘上相應(yīng)的權(quán)重矩陣W,得到相應(yīng)的矩陣Q、K、V,分別表示query,key和value。
三個(gè)W是我們需要學(xué)習(xí)的參數(shù)
同樣,q1到q4也可以拼接成矩陣Q直接與矩陣K相乘:
公式為:
矩陣A中的每一個(gè)值記錄了對應(yīng)的兩個(gè)輸入向量的Attention的大小α,A'是經(jīng)過softmax歸一化后的矩陣。
寫成矩陣形式:
對self-attention操作過程做個(gè)總結(jié),輸入是I,輸出是O:
矩陣Wq、 Wk 、Wv是需要學(xué)習(xí)的參數(shù)。
四、Multi-head Self-attention
self-attention的進(jìn)階版本Multi-head Self-attention,多頭自注意力機(jī)制
因?yàn)橄嚓P(guān)性有很多種不同的形式,有很多種不同的定義,所以有時(shí)不能只有一個(gè)q,要有多個(gè)q,不同的q負(fù)責(zé)不同種類的相關(guān)性。
對于1個(gè)輸入a
首先,和上面一樣,用a乘權(quán)重矩陣W得到,然后再用乘兩個(gè)不同的W,得到兩個(gè)不同的,i代表的是位置,1和2代表的是這個(gè)位置的第幾個(gè)q。
這上面這個(gè)圖中,有兩個(gè)head,代表這個(gè)問題有兩種不同的相關(guān)性。
同樣,k和v也需要有多個(gè),兩個(gè)k、v的計(jì)算方式和q相同,都是先算出來ki和vi,然后再乘兩個(gè)不同的權(quán)重矩陣。
對于多個(gè)輸入向量也一樣,每個(gè)向量都有多個(gè)head:
算出來q、k、v之后怎么做self-attention呢?
和上面講的過程一樣,只不過是1那類的一起做,2那類的一起做,兩個(gè)獨(dú)立的過程,算出來兩個(gè)b。
對于1:
對于2:
這只是兩個(gè)head的例子,有多個(gè)head過程也一樣,都是分開算b。
五、Positional Encoding
在訓(xùn)練self attention的時(shí)候,實(shí)際上對于位置的信息是缺失的,沒有前后的區(qū)別,上面講的a1,a2,a3不代表輸入的順序,只是指輸入的向量數(shù)量,不像rnn,對于輸入有明顯的前后順序,比如在翻譯任務(wù)里面,對于“機(jī)器學(xué)習(xí)”,機(jī)器學(xué)習(xí)依次輸入。而self-attention的輸入是同時(shí)輸入,輸出也是同時(shí)產(chǎn)生然后輸出的。
如何在Self-Attention里面體現(xiàn)位置信息呢?就是使用Positional Encoding
如果ai加上了ei,就會體現(xiàn)出位置的信息,i是多少,位置就是多少。
vector長度是人為設(shè)定的,也可以從數(shù)據(jù)中訓(xùn)練出來。
六、Self-Attention和RNN的區(qū)別
Self-attention和RNN的主要區(qū)別在于:
1.Self-attention可以考慮全部的輸入,而RNN似乎只能考慮之前的輸入(左邊)。但是當(dāng)使用雙向RNN的時(shí)候可以避免這一問題。
2.Self-attention可以容易地考慮比較久之前的輸入,而RNN的最早輸入由于經(jīng)過了很多層網(wǎng)絡(luò)的處理變得較難考慮。
3.Self-attention可以并行計(jì)算,而RNN不同層之間具有先后順序。
1.Self-attention可以考慮全部的輸入,而RNN似乎只能考慮之前的輸入(左邊)。但是當(dāng)使用雙向RNN的時(shí)候可以避免這一問題。
比如,對于第一個(gè)RNN,只考慮了深藍(lán)色的輸入,綠色及綠色后面的輸入不會考慮,而Self-Attention對于4個(gè)輸入全部考慮
2.Self-attention可以容易地考慮比較久之前的輸入,而RNN的最早輸入由于經(jīng)過了很多層網(wǎng)絡(luò)的處理變得較難考慮。
比如對于最后一個(gè)RNN的黃色輸出,想要包含最開始的藍(lán)色輸入,必須保證藍(lán)色輸入在經(jīng)過每層時(shí)信息都不丟失,但如果一個(gè)sequence很長,就很難保證。而Self-attention每個(gè)輸出都和所有輸入直接有關(guān)。
3.Self-attention可以并行計(jì)算,而RNN不同層之間具有先后順序。
Self-attention的輸入是同時(shí)輸入,輸出也是同時(shí)輸出。
往期精彩回顧
交流群
歡迎加入機(jī)器學(xué)習(xí)愛好者微信群一起和同行交流,目前有機(jī)器學(xué)習(xí)交流群、博士群、博士申報(bào)交流、CV、NLP等微信群,請掃描下面的微信號加群,備注:”昵稱-學(xué)校/公司-研究方向“,例如:”張小明-浙大-CV“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~(也可以加入機(jī)器學(xué)習(xí)交流qq群772479961)
