GCN (Graph Convolutional Networks) and Layer Fusion
數(shù)據(jù)其實(shí)是圖(graph),圖在生活中無處不在,如社交網(wǎng)絡(luò),知識圖譜,蛋白質(zhì)結(jié)構(gòu)等。本文介紹GNN(Graph Neural Networks)中的分支:GCN(Graph Convolutional Networks)。
PyTorch Implementation of GCN
Although GCN is somewhat hard to digest mathematically, its implementation is very simple. One point worth noting: the adjacency matrix is usually sparse, so using sparse operations for the matrix multiplication is more efficient. First, the graph convolution layer:
import torch
import torch.nn as nn
class GraphConvolution(nn.Module):
    """GCN layer"""

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input, adj):
        # X * W: dense matrix multiplication
        support = torch.mm(input, self.weight)
        # A * (X * W): sparse-dense multiplication with the normalized adjacency
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)
To build a GCN, we only need to stack graph convolution layers. Here is a two-layer GCN:
import torch.nn.functional as F


class GCN(nn.Module):
    """a simple two-layer GCN"""

    def __init__(self, nfeat, nhid, nclass):
        super(GCN, self).__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)

    def forward(self, input, adj):
        h1 = F.relu(self.gc1(input, adj))
        logits = self.gc2(h1, adj)
        return logits
The activation function used here is ReLU. Next, this network is used for a semi-supervised node classification task on a graph.
Extracting the data only requires calling load_data (see the utils.py of the official repo):
https://github.com/tkipf/pygcn/blob/master/pygcn/utils.py
adj, features, labels, idx_train, idx_val, idx_test = load_data(path="./data/cora/")
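Since the layer's forward pass uses torch.spmm, the adjacency matrix handed to the model must be a torch sparse tensor rather than a scipy matrix. A minimal sketch of such a conversion (a hypothetical helper, similar in spirit to what the official utils.py does; the name and details here are illustrative):

import numpy as np
import scipy.sparse as sp
import torch


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse COO tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    return torch.sparse_coo_tensor(indices, values, sparse_mx.shape)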
Two points are worth noting. First, paper citations form a directed graph, but when building the network the graph should first be converted to an undirected one, i.e., the citations should be made bidirectional; this has a noticeable impact on the training results:
# build symmetric adjacency matrix
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
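To see what this one-liner does, here is a hypothetical toy example (not from the original post) with a single directed citation 0 -> 1: the missing reverse entry (1, 0) gets filled in, while nothing is double-counted.

import numpy as np
import scipy.sparse as sp

# toy directed adjacency: a single citation 0 -> 1
adj = sp.csr_matrix(np.array([[0., 1.],
                              [0., 0.]]))
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
print(adj.toarray())
# [[0. 1.]
#  [1. 0.]]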
Second, the official implementation normalizes the adjacency matrix with a plain (row-wise) mean normalization; symmetric normalization can of course be used instead:
import numpy as np
import scipy.sparse as sp


def normalize_adj(adj):
    """compute L = D^-0.5 * (A+I) * D^-0.5"""
    adj = adj + sp.eye(adj.shape[0])                     # add self-loops: A + I
    degree = np.array(adj.sum(1))                        # node degrees
    d_hat = sp.diags(np.power(degree, -0.5).flatten())   # D^-0.5
    norm_adj = d_hat.dot(adj).dot(d_hat)
    return norm_adj
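As a quick sanity check, here is the normalization applied to a hypothetical two-node graph with one undirected edge, using the normalize_adj function defined above:

A = sp.csr_matrix(np.array([[0., 1.],
                            [1., 0.]]))
print(normalize_adj(A).toarray())
# [[0.5 0.5]
#  [0.5 0.5]]  -> each entry is scaled by 1 / sqrt(d_i * d_j), with self-loops added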
Here, only the 140 labeled nodes in the graph are used to train the GCN. In each epoch the model computes features for all nodes, and the loss is evaluated on the training nodes:
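The training loop below assumes that the model, loss function, and optimizer have already been created. A minimal setup sketch (the hidden size, learning rate, and weight decay are illustrative values, not necessarily those of the official implementation):

import torch.nn as nn
import torch.optim as optim

n_class = labels.max().item() + 1
model = GCN(nfeat=features.shape[1], nhid=16, nclass=n_class)
criterion = nn.CrossEntropyLoss()   # the GCN above outputs raw logits
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
epochs = 200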
loss_history = []
val_acc_history = []
for epoch in range(epochs):
    model.train()
    logits = model(features, adj)
    loss = criterion(logits[idx_train], labels[idx_train])
    train_acc = accuracy(logits[idx_train], labels[idx_train])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    val_acc = test(idx_val)
    loss_history.append(loss.item())
    val_acc_history.append(val_acc.item())
    print("Epoch {:03d}: Loss {:.4f}, TrainAcc {:.4}, ValAcc {:.4f}".format(
        epoch, loss.item(), train_acc.item(), val_acc.item()))
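The accuracy and test helpers used in the loop are not shown above; here is a minimal sketch of what they might look like, written only to match how they are called (in a real script they would be defined before the loop):

def accuracy(logits, labels):
    # fraction of nodes whose argmax prediction matches the label
    preds = logits.argmax(dim=1)
    return (preds == labels).float().mean()


def test(idx):
    # evaluate the current model on the given node indices without tracking gradients
    model.eval()
    with torch.no_grad():
        logits = model(features, adj)
        return accuracy(logits[idx], labels[idx])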
After only 200 training epochs, the model reaches roughly 80% classification accuracy on the test set, which shows how powerful GCN is.
Fusing BN and Conv Layers
To implement this fusion in PyTorch, we need the parameters of both layers. nn.Conv2d parameters:
- filter weights W: conv.weight;
- bias b: conv.bias;
nn.BatchNorm2d parameters:
- scale gamma: bn.weight;
- shift beta: bn.bias;
- running mean: bn.running_mean;
- running variance: bn.running_var;
- numerical-stability constant: bn.eps.
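For reference, the algebra behind the fusion: at inference time BatchNorm applies a fixed affine transform to the convolution output, so the two layers collapse into a single convolution:

$$
y = \gamma \cdot \frac{(Wx + b) - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
  = W_{\text{fused}}\, x + b_{\text{fused}},
\qquad
W_{\text{fused}} = \mathrm{diag}\!\left(\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\right) W,
\qquad
b_{\text{fused}} = \beta + \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}}
$$

The fuse function below constructs exactly these two quantities and copies them into a new nn.Conv2d.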
The concrete implementation is as follows (Google Colab: https://colab.research.google.com/drive/1mRyq_LlJW4u_rArzzhEe_T6tmEWoNN1K):
import torch
import torchvision
def fuse(conv, bn):
    # NOTE: assumes the conv uses default dilation and groups
    fused = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=True
    )

    # setting weights: scale each output channel by gamma / sqrt(running_var + eps)
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fused.weight.copy_(torch.mm(w_bn, w_conv).view(fused.weight.size()))

    # setting bias: scale the conv bias the same way, then add the BN shift
    if conv.bias is not None:
        b_conv = conv.bias
    else:
        b_conv = torch.zeros(conv.weight.size(0))
    b_conv = bn.weight.mul(b_conv).div(torch.sqrt(bn.running_var + bn.eps))
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
        torch.sqrt(bn.running_var + bn.eps))
    fused.bias.copy_(b_conv + b_bn)

    return fused


# Testing
# we only run inference here, so gradient tracking is not needed
torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)
resnet18 = torchvision.models.resnet18(pretrained=True)
# removing all learning variables, etc
resnet18.eval()
model = torch.nn.Sequential(resnet18.conv1, resnet18.bn1)
f1 = model.forward(x)
fused = fuse(model[0], model[1])
f2 = fused.forward(x)
d = (f1 - f2).mean().item()
print("error:",d)
References:
- Semi-Supervised Classification with Graph Convolutional Networks https://arxiv.org/abs/1609.02907
- How to do Deep Learning on Graphs with Graph Convolutional Networks https://towardsdatascience.com/how-to-do-deep-learning-on-graphs-with-graph-convolutional-networks-7d2250723780
- Graph Convolutional Networks http://tkipf.github.io/graph-convolutional-networks
- Graph Convolutional Networks in PyTorch https://github.com/tkipf/pygcn
- A Review of Classic Spectral Graph Convolution Work: from ChebNet to GCN https://www.jianshu.com/p/2fd5a2454781
- Introduction to the Cora Dataset and Processing It with Python for GCN Tasks https://blog.csdn.net/yeziand01/article/details/93374216
- Speeding up model with fusing batch normalization and convolution http://learnml.today/speeding-up-model-with-fusing-batch-normalization-and-convolution-3