GNN教程:GraghSAGE算法細(xì)節(jié)詳解!
引言


一、Inductive learning v.s. Transductive learning
首先我們介紹一下什么是inductive learning。與其他類型的數(shù)據(jù)不同,圖數(shù)據(jù)中的每一個節(jié)點可以通過邊的關(guān)系利用其他節(jié)點的信息,這樣就產(chǎn)生了一個問題,如果訓(xùn)練集上的節(jié)點通過邊關(guān)聯(lián)到了預(yù)測集或者驗證集的節(jié)點,那么在訓(xùn)練的時候能否用它們的信息呢? 如果訓(xùn)練時用到了測試集或驗證集樣本的信息(或者說,測試集和驗證集在訓(xùn)練的時候是可見的), 我們把這種學(xué)習(xí)方式叫做transductive learning, 反之,稱為inductive learning。
顯然,我們所處理的大多數(shù)機器學(xué)習(xí)問題都是inductive learning, 因為我們刻意的將樣本集分為訓(xùn)練/驗證/測試,并且訓(xùn)練的時候只用訓(xùn)練樣本。然而,在GCN中,訓(xùn)練節(jié)點收集鄰居信息的時候,用到了測試或者驗證樣本,所以它是transductive的。
二、概述
GraphSAGE是一個inductive框架,在具體實現(xiàn)中,訓(xùn)練時它僅僅保留訓(xùn)練樣本到訓(xùn)練樣本的邊。inductive learning 的優(yōu)點是可以利用已知節(jié)點的信息為未知節(jié)點生成Embedding. GraphSAGE 取自 Graph SAmple and aggreGatE, SAmple指如何對鄰居個數(shù)進(jìn)行采樣。aggreGatE指拿到鄰居的embedding之后如何匯聚這些embedding以更新自己的embedding信息。下圖展示了GraphSAGE學(xué)習(xí)的一個過程:

對鄰居采樣 采樣后的鄰居embedding傳到節(jié)點上來,并使用一個聚合函數(shù)聚合這些鄰居信息以更新節(jié)點的embedding 根據(jù)更新后的embedding預(yù)測節(jié)點的標(biāo)簽
三、算法細(xì)節(jié)
3.1 節(jié)點 Embedding 生成(即:前向傳播)算法
這一節(jié)討論的是如何給圖中的節(jié)點生成(或者說更新)embedding, 假設(shè)我們已經(jīng)完成了GraphSAGE的訓(xùn)練,因此模型所有的參數(shù)(parameters)都已知了。具體來說,這些參數(shù)包括個聚合器(見下圖算法第4行)中的參數(shù), 這些聚合器被用來將鄰居embedding信息聚合到節(jié)點上,以及一系列的權(quán)重矩陣(下圖算法第5行), 這些權(quán)值矩陣被用作在模型層與層之間傳播embedding的時候做非線性變換。
下面的算法描述了我們是怎么做前向傳播的:

算法的主要部分為:
(line 1)初始化每個節(jié)點embedding為節(jié)點的特征向量 (line 3)對于每一個節(jié)點 (line 4)拿到它采樣后的鄰居的embedding并將其聚合,這里表示對鄰居采樣 (line 5)根據(jù)聚合后的鄰居embedding()和自身embedding()通過一個非線性變換()更新自身embedding.
算法里的這個比較難理解,下面單獨來說他,之前提到過,它既是聚合器的數(shù)量,也是權(quán)重矩陣的數(shù)量,還是網(wǎng)絡(luò)的層數(shù),這是因為每一層網(wǎng)絡(luò)中聚合器和權(quán)重矩陣是共享的。
網(wǎng)絡(luò)的層數(shù)可以理解為需要最大訪問到的鄰居的跳數(shù)(hops),比如在figure 1中,紅色節(jié)點的更新拿到了它一、二跳鄰居的信息,那么網(wǎng)絡(luò)層數(shù)就是2。
為了更新紅色節(jié)點,首先在第一層()我們會將藍(lán)色節(jié)點的信息聚合到紅色節(jié)點上,將綠色節(jié)點的信息聚合到藍(lán)色節(jié)點上。在第二層()紅色節(jié)點的embedding被再次更新,不過這次用的是更新后的藍(lán)色節(jié)點embedding,這樣就保證了紅色節(jié)點更新后的embedding包括藍(lán)色和綠色節(jié)點的信息。
3.2 采樣 (SAmple) 算法
GraphSAGE采用了定長抽樣的方法,具體來說,定義需要的鄰居個數(shù), 然后采用有放回的重采樣/負(fù)采樣方法達(dá)到,。保證每個節(jié)點(采樣后的)鄰居個數(shù)一致是為了把多個節(jié)點以及他們的鄰居拼成Tensor送到GPU中進(jìn)行批訓(xùn)練。
3.3 聚合器 (Aggregator) 架構(gòu)
GraphSAGE 提供了多種聚合器,實驗中效果最好的平均聚合器(mean aggregator),平均聚合器的思慮很簡單,每個維度取對鄰居embedding相應(yīng)維度的均值,這個和GCN的做法基本一致(GCN實際上用的是求和):
舉個簡單例子,比如一個節(jié)點的3個鄰居的embedding分別為? ,按照每一維分別求均值就得到了聚合后的鄰居embedding為?.
論文中還闡述了另外兩種aggregator: LSTM aggregator 和 Pooling aggregator, 有興趣的可以去論文中看下。
3.4 參數(shù)學(xué)習(xí)
到此為止,整個模型的架構(gòu)就講完了,那么GraphSAGE是如何學(xué)習(xí)聚合器的參數(shù)以及權(quán)重變量的呢? 在有監(jiān)督的情況下,可以使用每個節(jié)點的預(yù)測label和真實label的交叉熵作為損失函數(shù)。在無監(jiān)督的情況下,可以假設(shè)相鄰的節(jié)點的輸出embeding應(yīng)當(dāng)盡可能相近,因此可以設(shè)計出如下的損失函數(shù):
其中是節(jié)點的輸出embedding,?是節(jié)點的鄰居(這里鄰居是廣義的,比如說如果和在一個定長的隨機游走中可達(dá),那么我們也認(rèn)為他們相鄰),是負(fù)采樣分布,是負(fù)采樣的樣本數(shù)量,所謂負(fù)采樣指我們還需要一批不是鄰居的節(jié)點作為負(fù)樣本,那么上面這個式子的意思是相鄰節(jié)點的embedding的相似度盡量大的情況下保證不相鄰節(jié)點的embedding的期望相似度盡可能小。
四、后話
GraphSAGE采用了采樣的機制,克服了GCN訓(xùn)練時內(nèi)存和顯存上的限制,使得圖模型可以應(yīng)用到大規(guī)模的圖結(jié)構(gòu)數(shù)據(jù)中,是目前幾乎所有工業(yè)上圖模型的雛形。然而,每個節(jié)點這么多鄰居,采樣能否考慮到鄰居的相對重要性呢,或者我們在聚合計算中能否考慮到鄰居的相對重要性? 這個問題在我們的下一篇博文Graph Attentioin Networks中做了詳細(xì)的討論。
Reference
[1] Inductive Representation Learning on Large Graphs(http://arxiv.org/abs/1706.02216)
