tf_geometricTensorFlow 圖神經(jīng)網(wǎng)絡(luò)框架
tf_geometric是一個(gè)高效且友好的圖神經(jīng)網(wǎng)絡(luò)庫(kù),同時(shí)支持TensorFlow 1.x 和 2.x。
受到rusty1s/pytorch_geometric項(xiàng)目的啟發(fā),我們?yōu)門ensorFlow構(gòu)建了一個(gè)圖神經(jīng)網(wǎng)絡(luò)(GNN)庫(kù)。 tf_geometric 同時(shí)提供面向?qū)ο蠼涌冢∣OP API)和函數(shù)式接口(Functional API),你可以用它們來構(gòu)建有趣的模型。
- Github主頁(yè): https://github.com/CrawlScript/tf_geometric
- 開發(fā)文檔: https://tf-geometric.readthedocs.io
- 論文: Efficient Graph Deep Learning in TensorFlow with tf_geometric
高效且友好的API
tf_geometric使用消息傳遞機(jī)制來實(shí)現(xiàn)圖神經(jīng)網(wǎng)絡(luò):相比于基于稠密矩陣的實(shí)現(xiàn),它具有更高的效率;相比于基于稀疏矩陣的實(shí)現(xiàn),它具有更友好的API。 除此之外,tf_geometric還為復(fù)雜的圖神經(jīng)網(wǎng)絡(luò)操作提供了簡(jiǎn)易優(yōu)雅的API。 下面的示例展現(xiàn)了使用tf_geometric構(gòu)建一個(gè)圖結(jié)構(gòu)的數(shù)據(jù),并使用多頭圖注意力網(wǎng)絡(luò)(Multi-head GAT)對(duì)圖數(shù)據(jù)進(jìn)行處理的流程:
# coding=utf-8 import numpy as np import tf_geometric as tfg import tensorflow as tf graph = tfg.Graph( x=np.random.randn(5, 20), # 5個(gè)節(jié)點(diǎn), 20維特征 edge_index=[[0, 0, 1, 3], [1, 2, 2, 1]] # 4個(gè)無向邊 ) print("Graph Desc: \n", graph) graph.convert_edge_to_directed() # 預(yù)處理邊數(shù)據(jù),將無向邊表示轉(zhuǎn)換為有向邊表示 print("Processed Graph Desc: \n", graph) print("Processed Edge Index:\n", graph.edge_index) # 多頭圖注意力網(wǎng)絡(luò)(Multi-head GAT) gat_layer = tfg.layers.GAT(units=4, num_heads=4, activation=tf.nn.relu) output = gat_layer([graph.x, graph.edge_index]) print("Output of GAT: \n", output)
輸出:
Graph Desc: Graph Shape: x => (5, 20) edge_index => (2, 4) y => None Processed Graph Desc: Graph Shape: x => (5, 20) edge_index => (2, 8) y => None Processed Edge Index: [[0 0 1 1 1 2 2 3] [1 2 0 2 3 0 1 1]] Output of GAT: tf.Tensor( [[0.22443159 0. 0.58263206 0.32468423] [0.29810357 0. 0.19403605 0.35630274] [0.18071976 0. 0.58263206 0.32468423] [0.36123228 0. 0.88897204 0.450244 ] [0. 0. 0.8013462 0. ]], shape=(5, 4), dtype=float32)
入門教程
教程列表
使用示例進(jìn)行快速入門
強(qiáng)烈建議您通過下面的示例代碼來快速入門tf_geometric:
節(jié)點(diǎn)分類
- 圖卷積網(wǎng)絡(luò) Graph Convolutional Network (GCN)
- 多頭圖注意力網(wǎng)絡(luò) Multi-head Graph Attention Network (GAT)
- Approximate Personalized Propagation of Neural Predictions (APPNP)
- Inductive Representation Learning on Large Graphs (GraphSAGE)
- 切比雪夫網(wǎng)絡(luò) Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering (ChebyNet)
- Simple Graph Convolution (SGC)
- Topology Adaptive Graph Convolutional Network (TAGCN)
- Deep Graph Infomax (DGI)
- DropEdge: Towards Deep Graph Convolutional Networks on Node Classification (DropEdge)
- 基于圖卷積網(wǎng)絡(luò)的文本分類 Graph Convolutional Networks for Text Classification (TextGCN)
圖分類
- 平均池化 MeanPooling
- Graph Isomorphism Network (GIN)
- 自注意力圖池化 Self-Attention Graph Pooling (SAGPooling)
- 可微池化 Hierarchical Graph Representation Learning with Differentiable Pooling (DiffPool)
- Order Matters: Sequence to Sequence for Sets (Set2Set)
- ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations (ASAP)
- An End-to-End Deep Learning Architecture for Graph Classification (SortPool)
- 最小割池化 Spectral Clustering with Graph Neural Networks for Graph Pooling (MinCutPool)
鏈接預(yù)測(cè)
分布式訓(xùn)練
引用
如果您在科研出版物中使用了tf_geometric,歡迎引用下方的論文:
@misc{hu2021efficient,
title={Efficient Graph Deep Learning in TensorFlow with tf_geometric},
author={Jun Hu and Shengsheng Qian and Quan Fang and Youze Wang and Quan Zhao and Huaiwen Zhang and Changsheng Xu},
year={2021},
eprint={2101.11552},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
評(píng)論
圖片
表情
