Graph Node Classification in 10 Minutes with PyTorch and DGL
Node classification is one of the most popular and widely adopted tasks for graph neural networks: each node in the training/validation/test set is assigned a ground-truth label from a predefined set of classes.
To classify nodes, a graph neural network performs message passing, using each node's own features together with the features of its neighboring nodes and edges. Message passing can be repeated several times to aggregate information from progressively larger neighborhoods.
The DGL framework ships built-in graph convolution modules, each of which performs one round of message passing.
In this article we use the SAGEConv module from dgl.nn.pytorch, which implements the paper GraphSAGE: Inductive Representation Learning on Large Graphs.
For deep learning models on graphs we usually want a multi-layer graph neural network that performs several rounds of message passing; this is achieved by stacking graph convolution modules, as shown below.
1 Building the GNN Model
First, import the required packages (this article uses DGL version 0.5.2):
import dgl.nn as dglnn
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import *

Now construct a two-layer GNN model:
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, dropout=0.2):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(in_feats=in_feats, out_feats=hid_feats,
                                    feat_drop=0.2, aggregator_type='gcn')
        self.conv2 = dglnn.SAGEConv(in_feats=hid_feats, out_feats=out_feats,
                                    feat_drop=0.2, aggregator_type='mean')
        self.dropout = nn.Dropout(dropout)

    def forward(self, graph, inputs):
        # inputs: node feature matrix of shape [N, in_feats]
        h = self.conv1(graph, inputs)
        h = self.dropout(F.relu(h))
        h = self.conv2(graph, h)
        return h

Note that this model is useful beyond node classification: the node representations it produces can also serve other downstream tasks such as edge classification/regression, link prediction, or graph classification.
2 The Dataset and Some Analysis
dataset = CoraGraphDataset() # Cora citation network dataset
graph = dataset[0]
graph = dgl.remove_self_loop(graph)  # remove self-loops
node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1)

print("Number of nodes and edges:", graph.num_nodes(), graph.num_edges())
print("Training nodes:", train_mask.sum().item())
print("Validation nodes:", valid_mask.sum().item())
print("Test nodes:", test_mask.sum().item())
print("Node feature dimension:", n_features)
print("Number of label classes:", n_labels)

Now randomly sample 200 nodes and plot them:
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

G = graph.to_networkx()
# note: randint samples with replacement, so the subgraph may contain
# fewer than 200 distinct nodes
res = np.random.randint(0, high=G.number_of_nodes(), size=200)
k = G.subgraph(res)
pos = nx.spring_layout(k)
plt.figure()
nx.draw(k, pos=pos, node_size=8 )
plt.savefig('cora.jpg', dpi=600)
plt.show()

3 Training and Evaluation
def evaluate(model, graph, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

model = SAGE(in_feats=n_features, hid_feats=128, out_feats=n_labels)
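The core of evaluate is an argmax over class scores followed by a masked comparison against the labels. Here is the same logic in plain Python on made-up logits (a sketch for intuition only, not part of the tutorial's pipeline):

```python
# Made-up logits for 4 nodes over 3 classes (illustrative numbers only).
logits = [[2.0, 0.1, 0.3],   # argmax -> class 0
          [0.2, 1.5, 0.1],   # argmax -> class 1
          [0.3, 0.2, 0.9],   # argmax -> class 2
          [1.0, 0.9, 0.8]]   # argmax -> class 0
labels = [0, 1, 1, 0]
mask = [True, True, True, False]   # score only the first three nodes

def masked_accuracy(logits, labels, mask):
    # row.index(max(row)) plays the role of torch.max(logits, dim=1)
    preds = [row.index(max(row)) for row in logits]
    kept = [(p, y) for p, y, m in zip(preds, labels, mask) if m]
    correct = sum(p == y for p, y in kept)
    return correct / len(kept)

acc = masked_accuracy(logits, labels, mask)  # 2 of the 3 masked nodes are right
```

Swapping the Python lists for tensors and `row.index(max(row))` for `torch.max(..., dim=1)` recovers the tensor version of this computation.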
opt = torch.optim.Adam(model.parameters())

# start training
best_val_acc = 0
import os
os.makedirs('save_model', exist_ok=True)  # make sure the checkpoint directory exists

for epoch in range(200):
    print('Epoch {}'.format(epoch))
    model.train()
    # forward pass using all nodes
    logits = model(graph, node_features)
    # compute the loss on training nodes only
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
    # compute validation accuracy
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)
    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print('loss = {:.4f}'.format(loss.item()))
    if acc > best_val_acc:
        best_val_acc = acc
        torch.save(model.state_dict(), 'save_model/best_model.pth')
    print("current val acc = {}, best val acc = {}".format(acc, best_val_acc))

Evaluate on the test set:
model.load_state_dict(torch.load("save_model/best_model.pth"))
acc = evaluate(model, graph, node_features, node_labels, test_mask)
print("test accuracy:", acc)
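As an aside, the checkpointing rule used above (save whenever validation accuracy improves, then reload the best weights for testing) can be isolated in a few lines of plain Python, with made-up accuracy values:

```python
# Made-up per-epoch validation accuracies (illustrative only).
val_accs = [0.55, 0.62, 0.60, 0.71, 0.68]

best_acc = 0.0
best_epoch = None
saves = 0                    # stands in for calls to torch.save(...)
for epoch, acc in enumerate(val_accs):
    if acc > best_acc:       # same condition as in the training loop
        best_acc = acc
        best_epoch = epoch
        saves += 1           # the real loop writes best_model.pth here

print(best_epoch, best_acc, saves)  # 3 0.71 3
```

Because the saved file always holds the weights from the best validation epoch, the final test accuracy is reported for that checkpoint rather than for the last epoch.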