GNN教程:DGL框架實(shí)現(xiàn)GCN算法!
引言
本文為GNN教程的第七篇文章【使用DGL框架實(shí)現(xiàn)GCN算法】。圖神經(jīng)網(wǎng)絡(luò)的計算模式大致相似,節(jié)點(diǎn)的Embedding需要匯聚其鄰接節(jié)點(diǎn)Embedding以更新,從線性代數(shù)的角度來看,這就是鄰接矩陣和特征矩陣相乘。然而鄰接矩陣通常都會很大,因此另一種計算方法是將鄰居的Embedding傳遞到當(dāng)前節(jié)點(diǎn)上,再進(jìn)行更新。很多圖并行框架都采用詳細(xì)傳遞的機(jī)制進(jìn)行運(yùn)算(比如Google的Pregel)。而圖神經(jīng)網(wǎng)絡(luò)框架DGL也采用了這樣的思路。
后臺回復(fù)【GNN】進(jìn)圖神經(jīng)網(wǎng)絡(luò)交流群。
從本篇博文開始,我們使用DGL做一個系統(tǒng)的介紹,我們主要關(guān)注他的設(shè)計,尤其是應(yīng)對大規(guī)模圖計算的設(shè)計。這篇文章將會介紹DGL的核心概念 — 消息傳遞機(jī)制,并且使用DGL框架實(shí)現(xiàn)GCN算法。

DGL 核心 — 消息傳遞
DGL 的核心為消息傳遞機(jī)制(message passing),主要分為消息函數(shù) (message function)和匯聚函數(shù)(reduce function)。如下圖所示:

消息函數(shù)(message function):傳遞消息的目的是將節(jié)點(diǎn)計算時需要的信息傳遞給它,因此對每條邊來說,每個源節(jié)點(diǎn)將會將自身的Embedding(e.src.data)和邊的Embedding(edge.data)傳遞到目的節(jié)點(diǎn);對于每個目的節(jié)點(diǎn)來說,它可能會受到多個源節(jié)點(diǎn)傳過來的消息,它會將這些消息存儲在"郵箱"中。 匯聚函數(shù)(reduce function):匯聚函數(shù)的目的是根據(jù)鄰居傳過來的消息更新跟新自身節(jié)點(diǎn)Embedding,對每個節(jié)點(diǎn)來說,它先從郵箱(v.mailbox['m'])中匯聚消息函數(shù)所傳遞過來的消息(message),并清空郵箱(v.mailbox['m'])內(nèi)消息;然后該節(jié)點(diǎn)結(jié)合匯聚后的結(jié)果和該節(jié)點(diǎn)原Embedding,更新節(jié)點(diǎn)Embedding。
下面我們以GCN的算法為例,詳細(xì)說明消息傳遞的機(jī)制是如何work的。
用消息傳遞的方式實(shí)現(xiàn)GCN
GCN 的線性代數(shù)表達(dá)
GCN 的逐層傳播公式如下所示:
從線性代數(shù)的角度,節(jié)點(diǎn)Embedding的的更新方式為首先左乘鄰接矩陣以匯聚鄰居Embedding,再為新Embedding做一次線性變換(右乘)。
簡而言之:每個節(jié)點(diǎn)拿到鄰居節(jié)點(diǎn)信息匯聚到自身 embedding 上在進(jìn)行一次變換。具體 GCN 內(nèi)容介紹可參考之前的文章
從消息傳遞的角度分析
上面的數(shù)學(xué)描述可以利用消息傳遞的機(jī)制實(shí)現(xiàn)為:
在 GCN 中每個節(jié)點(diǎn)都有屬于自己的表示?; 根據(jù)消息傳遞(message passing)的范式,每個節(jié)點(diǎn)將會收到來自鄰居節(jié)點(diǎn)發(fā)送的Embedding; 每個節(jié)點(diǎn)將會對來自鄰居節(jié)點(diǎn)的 Embedding進(jìn)行匯聚以得到中間表示? ; 對中間節(jié)點(diǎn)表示? 進(jìn)行線性變換,然后在利用非線性函數(shù)進(jìn)行計算:; 利用新的節(jié)點(diǎn)表示? 對該節(jié)點(diǎn)的表示?進(jìn)行更新。
具體實(shí)現(xiàn)
step 1,引入相關(guān)包
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
step 2,我們需要定義 GCN 的 message 函數(shù)和 reduce 函數(shù), message 函數(shù)用于發(fā)送節(jié)點(diǎn)的Embedding,reduce 函數(shù)用來對收到的 Embedding 進(jìn)行聚合。在這里,每個節(jié)點(diǎn)發(fā)送Embedding的時候不需要任何處理,所以可以通過內(nèi)置的copy_scr實(shí)現(xiàn),out='m'表示發(fā)送到目的節(jié)點(diǎn)后目的節(jié)點(diǎn)的mailbox用m來標(biāo)識這個消息是源節(jié)點(diǎn)的Embedding。
目的節(jié)點(diǎn)的reduce函數(shù)很簡單,因?yàn)榘凑誈CN的數(shù)學(xué)定義,鄰接矩陣和特征矩陣相乘,以為這更新后的特征矩陣的每一行是原特征矩陣某幾行相加的形式,"某幾行"是由鄰接矩陣選定的,即對應(yīng)節(jié)點(diǎn)的鄰居所在的行。因此目的節(jié)點(diǎn)reduce只需要通過sum將接受到的信息相加就可以了。
gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce = fn.sum(msg='m', out='h')
step 3,我們定義一個應(yīng)用于節(jié)點(diǎn)的 node UDF(user defined function),即定義一個全連接層(fully-connected layer)來對中間節(jié)點(diǎn)表示? 進(jìn)行線性變換,然后在利用非線性函數(shù)進(jìn)行計算:。
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation):
super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node):
h = self.linear(node.data['h'])
h = self.activation(h)
return {'h' : h}
step 4,我們定義 GCN 的Embedding更新層,以實(shí)現(xiàn)在所有節(jié)點(diǎn)上進(jìn)行消息傳遞,并利用 NodeApplyModule 對節(jié)點(diǎn)信息進(jìn)行計算更新。
class GCN(nn.Module):
def __init__(self, in_feats, out_feats, activation):
super(GCN, self).__init__()
self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)
def forward(self, g, feature):
g.ndata['h'] = feature
g.update_all(gcn_msg, gcn_reduce)
g.apply_nodes(func=self.apply_mod)
return g.ndata.pop('h')
step 5,最后,我們定義了一個包含兩個 GCN 層的圖神經(jīng)網(wǎng)絡(luò)分類器。我們通過向該分類器輸入特征大小為 1433 的訓(xùn)練樣本,以獲得該樣本所屬的類別編號,類別總共包含 7 類。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.gcn1 = GCN(1433, 16, F.relu)
self.gcn2 = GCN(16, 7, F.relu)
def forward(self, g, features):
x = self.gcn1(g, features)
x = self.gcn2(g, x)
return x
net = Net()
print(net)
step 6,加載 cora 數(shù)據(jù)集,并進(jìn)行數(shù)據(jù)預(yù)處理。
from dgl.data import citation_graph as citegrh
def load_cora_data():
data = citegrh.load_cora()
features = th.FloatTensor(data.features)
labels = th.LongTensor(data.labels)
mask = th.ByteTensor(data.train_mask)
g = data.graph
# add self loop
g.remove_edges_from(g.selfloop_edges())
g = DGLGraph(g)
g.add_edges(g.nodes(), g.nodes())
return g, features, labels, mask
step 7,訓(xùn)練 GCN 神經(jīng)網(wǎng)絡(luò)。
import time
import numpy as np
g, features, labels, mask = load_cora_data()
optimizer = th.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(30):
if epoch >=3:
t0 = time.time()
logits = net(g, features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[mask], labels[mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch >=3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
epoch, loss.item(), np.mean(dur)))
后話
本篇博文介紹了如何利用圖神經(jīng)網(wǎng)絡(luò)框架DGL編寫GCN模型,接下來我們會介紹如何利用DGL實(shí)現(xiàn)GraphSAGE中的采樣機(jī)制,以減少運(yùn)算規(guī)模。
Reference
DGL Basics Graph Convolutional Network PageRank with DGL Message Passing DGL 作者答疑!關(guān)于 DGL 你想知道的都在這里
