gcn代码pytorch_GCN的简单实现(pytorch)
生活随笔
收集整理的這篇文章主要介紹了
gcn代码pytorch_GCN的简单实现(pytorch)
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
import torch
import torch.nn as nn
import torch.nn.functional as Fimport networkx as nxdef normalize(A , symmetric=True):# A = A+IA = A + torch.eye(A.size(0))# 所有節(jié)點的度d = A.sum(1)if symmetric:#D = D^-1/2D = torch.diag(torch.pow(d , -0.5))return D.mm(A).mm(D)else :# D=D^-1D =torch.diag(torch.pow(d,-1))return D.mm(A)class GCN(nn.Module):'''Z = AXW'''def __init__(self , A, dim_in , dim_out):super(GCN,self).__init__()self.A = Aself.fc1 = nn.Linear(dim_in ,dim_in,bias=False)self.fc2 = nn.Linear(dim_in,dim_in//2,bias=False)self.fc3 = nn.Linear(dim_in//2,dim_out,bias=False)def forward(self,X):'''計算三層gcn'''X = F.relu(self.fc1(self.A.mm(X)))X = F.relu(self.fc2(self.A.mm(X)))return self.fc3(self.A.mm(X))#獲得空手道俱樂部數(shù)據(jù)
G = nx.karate_club_graph()
A = nx.adjacency_matrix(G).todense()
#A需要正規(guī)化
A_normed = normalize(torch.FloatTensor(A),True)N = len(A)
X_dim = N# 沒有節(jié)點的特征,簡單用一個單位矩陣表示所有節(jié)點
X = torch.eye(N,X_dim)
# 正確結(jié)果
Y = torch.zeros(N,1).long()
# 計算loss的時候要去掉沒有標記的樣本
Y_mask = torch.zeros(N,1,dtype=torch.uint8)
# 一個分類給一個樣本
Y[0][0]=0
Y[N-1][0]=1
#有樣本的地方設(shè)置為1
Y_mask[0][0]=1
Y_mask[N-1][0]=1#真實的空手道俱樂部的分類數(shù)據(jù)
Real = torch.zeros(34 , dtype=torch.long)
for i in [1,2,3,4,5,6,7,8,11,12,13,14,17,18,20,22] :Real[i-1] = 0
for i in [9,10,15,16,19,21,23,24,25,26,27,28,29,30,31,32,33,34] :Real[i-1] = 1# 我們的GCN模型
gcn = GCN(A_normed ,X_dim,2)
#選擇adam優(yōu)化器
gd = torch.optim.Adam(gcn.parameters())for i in range(300):#轉(zhuǎn)換到概率空間y_pred =F.softmax(gcn(X),dim=1)#下面兩行計算cross entropyloss = (-y_pred.log().gather(1,Y.view(-1,1)))#僅保留有標記的樣本loss = loss.masked_select(Y_mask).mean()#梯度下降#清空前面的導(dǎo)數(shù)緩存gd.zero_grad()#求導(dǎo)loss.backward()#一步更新gd.step()if i%20==0 :_,mi = y_pred.max(1)print(mi)#計算精確度print((mi == Real).float().mean())
總結(jié)
以上是生活随笔為你收集整理的gcn代码pytorch_GCN的简单实现(pytorch)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python中locals函数_Pyth
- 下一篇: 计算机基础及应用笔试,计算机基础及应用测