GNN教程:圖注意力網(wǎng)絡(luò)(GAT)詳解!
引言


圖注意力機制的類型
目前主要有三種注意力機制算法,它們分別是:學(xué)習(xí)注意力權(quán)重(Learn attention weights),基于相似性的注意力(Similarity-based attention),注意力引導(dǎo)的隨機游走(Attention-guided walk)。這三種注意力機制都可以用來生成鄰居的相對重要性,下文會闡述他們之間的差異。
首先我們對“圖注意力機制”做一個數(shù)學(xué)上的定義:
定義(圖注意力機制):給定一個圖中節(jié)點 和的鄰居節(jié)點
(這里的? 和GraphSAGE博文中的? 表示一個意思)。注意力機制被定義為將中每個節(jié)點映射到相關(guān)性得分(relevance score)的函數(shù)
相關(guān)性得分表示該鄰居節(jié)點的相對重要性。滿足:
下面再來看看這三種不同的圖注意力機制的具體細(xì)節(jié)
1. 學(xué)習(xí)注意力權(quán)重
學(xué)習(xí)注意力權(quán)重的方法來自于Velickovic et al. 2018 其核心思想是利用參數(shù)矩陣學(xué)習(xí)節(jié)點和鄰居之間的相對重要性。
給定節(jié)點相應(yīng)的特征(embedding)
節(jié)點和節(jié)點注意力權(quán)重可以通過以下公式計算:
其中, 表示節(jié)點對節(jié)點的相對重要性。在實踐中,可以利用節(jié)點的屬性結(jié)合softmax函數(shù)來計算間的相關(guān)性。比如,GAT 中是這樣計算的:
其中, 表示一個可訓(xùn)練的參數(shù)向量, 用來學(xué)習(xí)節(jié)點和鄰居之間的相對重要性, 也是一個可訓(xùn)練的參數(shù)矩陣,用來對輸入特征做線性變換,表示向量拼接(concate)。

如上圖,對于一個目標(biāo)對象, 表示它和鄰居的相對重要性權(quán)重。可以根據(jù)? 和? 的 embedding? 和? 計算,比如圖中 是由? 共同計算得到的。
2. 基于相似性的注意力
上面這種方法使用一個參數(shù)向量學(xué)習(xí)節(jié)點和鄰居的相對重要性,其實另一個容易想到的點是:既然我們有節(jié)點的特征表示,假設(shè)和節(jié)點自身相像的鄰居節(jié)點更加重要,那么可以通過直接計算之間相似性的方法得到節(jié)點的相對重要性。這種方法稱為基于相似性的注意力機制,比如說論文 TheKumparampil et al. 2018 是這樣計算的:
其中, 表示可訓(xùn)練偏差(bias),函數(shù)用來計算余弦相似度,和上一個方法類似, 是一個可訓(xùn)練的參數(shù)矩陣,用來對輸入特征做線性變換。
這個方法和上一個方法的區(qū)別在于,這個方法顯示地使用函數(shù)計算節(jié)點之間的相似性作為相對重要性權(quán)重,而上一個方法使用可學(xué)習(xí)的參數(shù)學(xué)習(xí)節(jié)點之間的相對重要性。
3. 注意力引導(dǎo)的游走法
前兩種注意力方法主要關(guān)注于選擇相關(guān)的鄰居信息,并將這些信息聚合到節(jié)點的embedding中。第三種注意力的方法的目的不同,我們以Lee et al. 2018 作為例子:
GAM方法在輸入圖進(jìn)行一系列的隨機游走,并且通過RNN對已訪問節(jié)點進(jìn)行編碼,構(gòu)建子圖embedding。時間的RNN隱藏狀態(tài)? 編碼了隨機游走中? 步訪問到的節(jié)點。然后,注意力機制被定義為函數(shù)?,用于將輸入的隱向量映射到一個維向量中,可以通過比較這維向量每一維的數(shù)值確定下一步需要優(yōu)先游走到哪種類型的節(jié)點(假設(shè)一共有種節(jié)點類型)。下圖做了形象的闡述:

如上圖,聚合了長度的隨機游走得到的信息,我們將該信息輸入到排序函數(shù)中,以確定各個鄰居節(jié)點的重要性并用于影響下一步游走。
后話
至此,圖注意力機制就講完了,還有一些細(xì)節(jié)沒有涉及,比如在 GAT論文 中討論了對一個節(jié)點使用多個注意力機制(multi-head attention), 在AGNN論文中分析了注意力機制是否真的有效,詳細(xì)的可以參考原論文。
參考文獻(xiàn)
[1] Attention Models in Graphs: A Survey
[2] Graph Attention Networks
[3] Attention-based Graph Neural Network for Semi-supervised Learning
[4] Graph Classification using Structural Attention
