Introduction to the GAT model
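The GCN post covered how the Cora dataset is processed and what the loader returns:
- features: the attribute features of the papers, with shape $2708 \times 1433$. They are row-normalized, so the attribute values of each paper sum to 1.
- labels: the class index of each paper, from 0 to 6.
- adj: the adjacency matrix, with shape $2708 \times 2708$.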
- idx_train:0-139
- idx_val:200-499
- idx_test:500-1499
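The row normalization of features can be sketched in plain Python. This is a minimal illustration on a made-up 3 × 4 binary matrix, not the actual loader from the GCN post. The helper name `row_normalize` is hypothetical. The splits are written as end-exclusive Python `range` objects matching the indices listed above.

```python
def row_normalize(rows):
    """Scale each row so its entries sum to 1 (all-zero rows are left unchanged)."""
    result = []
    for row in rows:
        total = sum(row)
        result.append([v / total for v in row] if total else list(row))
    return result


# Made-up binary bag-of-words features: 3 papers, 4 words (illustrative only)
features = [
    [1, 0, 1, 0],
    [1, 1, 1, 1],
    [0, 0, 0, 1],
]
normalized = row_normalize(features)
print(normalized[0])  # [0.5, 0.0, 0.5, 0.0]

# Index splits from the list above, as end-exclusive ranges
idx_train = range(0, 140)
idx_val = range(200, 500)
idx_test = range(500, 1500)
print(len(idx_train), len(idx_val), len(idx_test))  # 140 300 1000
```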
This section introduces the GAT model:
The GAT model
model:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout

        # First layer: nheads independent attention heads whose outputs are concatenated
        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        # Second (final) attention layer
        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)  # concatenate the outputs of all heads
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))  # second attention layer
        return F.log_softmax(x, dim=1)
```

layers:
```python
class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))  # scores concat(V, NeigV)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        Wh = torch.mm(h, self.W)  # h.shape: (N, in_features), Wh.shape: (N, out_features)
        a_input = self._prepare_attentional_mechanism_input(Wh)  # every node paired with every node: (N, N, 2 * out_features)
        # a_input.shape = (2708, 2708, 16), self.a.shape = (16, 1)
        # torch.matmul(a_input, self.a).shape = (2708, 2708, 1); squeeze(2) drops the last dim -> (2708, 2708)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        # e holds scores between every pair of nodes, but only the scores of connected nodes are needed
        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)  # non-edges (adj == 0) get -9e15, i.e. effectively -inf
        attention = F.softmax(attention, dim=1)  # row-wise softmax: each row sums to 1
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, Wh)  # aggregate neighbor features

        if self.concat:
            return F.elu(h_prime)  # ELU activation
        else:
            return h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        N = Wh.size()[0]  # number of nodes

        # Below, two matrices are created that contain embeddings in their rows in different orders.
        # (e stands for embedding)
        # These are the rows of the first matrix (Wh_repeated_in_chunks):
        # e1, e1, ..., e1,            e2, e2, ..., e2,            ..., eN, eN, ..., eN
        # '-------------' -> N times  '-------------' -> N times       '-------------' -> N times
        #
        # These are the rows of the second matrix (Wh_repeated_alternating):
        # e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN
        # '----------------------------------------------------' -> N times
        #

        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)  # repeat each row N times consecutively
        Wh_repeated_alternating = Wh.repeat(N, 1)  # tile the whole matrix N times
        # Wh_repeated_in_chunks.shape == Wh_repeated_alternating.shape == (N * N, out_features)

        # The all_combinations_matrix, created below, will look like this (|| denotes concatenation):
        # e1 || e1
        # e1 || e2
        # e1 || e3
        # ...
        # e1 || eN
        # e2 || e1
        # e2 || e2
        # e2 || e3
        # ...
        # e2 || eN
        # ...
        # eN || e1
        # eN || e2
        # eN || e3
        # ...
        # eN || eN

        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
        # all_combinations_matrix.shape == (N * N, 2 * out_features)

        return all_combinations_matrix.view(N, N, 2 * self.out_features)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
```

Initializing the model
```python
model = GAT(nfeat=1433, nhid=8, nclass=7, dropout=0.6, nheads=8, alpha=0.2)
```

Building the attention layers:

```python
self.dropout = 0.6
self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
for i, attention in enumerate(self.attentions):
    self.add_module('attention_{}'.format(i), attention)
self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)  # second (final) attention layer
```

attentions and out_att
The attentions list is built first. It consists of 8 GraphAttentionLayer instances, one per head (nheads = 8). With the concrete values filled in, each GraphAttentionLayer is initialized like this:

```python
def __init__(self, in_features, out_features, dropout, alpha, concat=True):
    super(GraphAttentionLayer, self).__init__()
    self.dropout = 0.6
    self.in_features = 1433
    self.out_features = 8
    self.alpha = 0.2
    self.concat = True

    self.W = nn.Parameter(torch.empty(size=(1433, 8)))
    nn.init.xavier_uniform_(self.W.data, gain=1.414)  # initialize W
    self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))  # concat(V, NeigV)
    nn.init.xavier_uniform_(self.a.data, gain=1.414)  # initialize a

    self.leakyrelu = nn.LeakyReLU(0.2)
```

The parameter W has shape $W_{1433 \times 8}$.
The parameter a has shape $a_{16 \times 1}$, since $2 \times 8 = 16$.
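The sketch below computes the range of the initial values by hand. PyTorch's `nn.init.xavier_uniform_` samples from $U(-b, b)$ with $b = \text{gain} \cdot \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}$. For a 2-D tensor, fan_in + fan_out is simply rows + cols. The helper `xavier_uniform_bound` is hypothetical and only reproduces that formula for W (1433 × 8) and a (16 × 1) with gain = 1.414.

```python
import math


def xavier_uniform_bound(rows, cols, gain=1.0):
    """Half-width b of U(-b, b) used by Xavier/Glorot uniform init for a 2-D tensor."""
    fan_in, fan_out = cols, rows  # PyTorch treats a (rows, cols) weight this way
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


bound_W = xavier_uniform_bound(1433, 8, gain=1.414)  # W of one first-layer head
bound_a = xavier_uniform_bound(16, 1, gain=1.414)    # a of one first-layer head
print(bound_W, bound_a)  # roughly 0.0912 and 0.840
```

Because W is so tall, its initial entries are small (about ±0.09). The small vector a gets a much wider range.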
out_att
out_att is built like the attention heads, with two differences. It is a single GraphAttentionLayer, and its arguments differ (in_features = nhid * nheads = 64, out_features = nclass = 7, concat = False):

```python
def __init__(self, in_features, out_features, dropout, alpha, concat=True):
    super(GraphAttentionLayer, self).__init__()
    self.dropout = 0.6
    self.in_features = 64
    self.out_features = 7
    self.alpha = 0.2
    self.concat = False

    self.W = nn.Parameter(torch.empty(size=(64, 7)))
    nn.init.xavier_uniform_(self.W.data, gain=1.414)  # initialize W
    self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))  # concat(V, NeigV)
    nn.init.xavier_uniform_(self.a.data, gain=1.414)  # initialize a

    self.leakyrelu = nn.LeakyReLU(0.2)
```

The parameter W has shape $W_{64 \times 7}$.
The parameter a has shape $a_{14 \times 1}$, since $2 \times 7 = 14$.
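These shapes fix the number of trainable parameters. The tally below counts only W and a, because LeakyReLU and dropout have no parameters and the layer has no bias. The helper `head_params` is a hypothetical name used only for this count.

```python
def head_params(in_features, out_features):
    """Parameters of one GraphAttentionLayer: W (in x out) plus a (2*out x 1)."""
    return in_features * out_features + 2 * out_features


nfeat, nhid, nclass, nheads = 1433, 8, 7, 8

first_layer = nheads * head_params(nfeat, nhid)  # 8 heads, outputs concatenated
out_layer = head_params(nhid * nheads, nclass)    # out_att: 64 -> 7
print(first_layer, out_layer, first_layer + out_layer)  # 91840 462 92302
```

Almost all of the parameters sit in the eight 1433 × 8 matrices W of the first layer.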
forward: running the model
The feature matrix h is multiplied by the weight W of the first attention head: $W_h = x_{2708 \times 1433} \times W_{1433 \times 8}$, so $W_h$ has shape $2708 \times 8$. Then self._prepare_attentional_mechanism_input(Wh) is called.
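`torch.mm(h, self.W)` is an ordinary matrix product. Each row of h (one node) is mapped from in_features to out_features. The plain-Python sketch below uses made-up sizes: 3 nodes, 4 input features, and 2 output features, standing in for 2708 × 1433 and 1433 × 8. The `matmul` helper is only illustrative.

```python
def matmul(A, B):
    """Plain-Python matrix product: (n x k) @ (k x m) -> (n x m)."""
    assert len(A[0]) == len(B), "inner dimensions must match"
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


h = [  # 3 nodes, 4 input features (made-up values)
    [1, 0, 2, 0],
    [0, 1, 0, 1],
    [1, 1, 1, 1],
]
W = [  # 4 input features -> 2 output features (made-up values)
    [1, 0],
    [0, 1],
    [1, 1],
    [0, 2],
]
Wh = matmul(h, W)
print(len(Wh), len(Wh[0]))  # 3 2, i.e. (N, out_features)
print(Wh)  # [[3, 2], [0, 3], [2, 4]]
```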
5. Running self._prepare_attentional_mechanism_input(Wh):
Suppose Wh[0] contains:

```
Wh[0]
tensor([-0.0118, -0.0033, -0.0051, 0.0151, -0.0151, 0.0186, -0.0097, 0.0387],
       grad_fn=<SelectBackward>)
```

Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0) produces a matrix of shape $(2708 \cdot 2708) \times 8$. Rows Wh_repeated_in_chunks[0] through Wh_repeated_in_chunks[2707] are all equal to Wh[0]. Rows Wh_repeated_in_chunks[2708] through Wh_repeated_in_chunks[2707+2708] are all equal to Wh[1], and so on. This gives the following pattern:

```
# e1, e1, ..., e1,            e2, e2, ..., e2,            ..., eN, eN, ..., eN
# '-------------' -> N times  '-------------' -> N times       '-------------' -> N times
```

Wh_repeated_alternating = Wh.repeat(N, 1) also produces shape $(2708 \cdot 2708) \times 8$. Rows Wh_repeated_alternating[0] through Wh_repeated_alternating[2707] equal Wh[0] through Wh[2707]. Rows Wh_repeated_alternating[2708] through Wh_repeated_alternating[2707+2708] again equal Wh[0] through Wh[2707], and so on. This gives the following pattern:

```
# These are the rows of the second matrix (Wh_repeated_alternating):
# e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN
# '----------------------------------------------------' -> N times
```

all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1) concatenates Wh_repeated_in_chunks and Wh_repeated_alternating along the feature dimension. The result has shape $[2708 \cdot 2708, 16]$, with rows of this form:

```
# e1 || e1
# e1 || e2
# e1 || e3
# ...
# e1 || eN
# e2 || e1
# e2 || e2
# e2 || e3
# ...
# e2 || eN
# ...
# eN || e1
# eN || e2
# eN || e3
# ...
# eN || eN
```

view then reshapes the result into a_input with shape $[2708, 2708, 16]$. a_input[0] holds node 0's features concatenated with the features of every node, including itself, and a_input[1] does the same for node 1. In general, a_input[i][j] = Wh[i] || Wh[j].
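The sketch below mimics in plain Python the row orders produced by repeat_interleave and repeat, and the final view. It checks that a_input[i][j] really is Wh[i] || Wh[j]. It assumes a made-up 3-node, 2-feature Wh and uses lists instead of tensors, so it reproduces the indexing only, not PyTorch's memory layout. It also estimates how large a_input gets for Cora in float32.

```python
N, F_out = 3, 2
Wh = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # made-up transformed features, shape (N, F_out)

# Same row order as Wh.repeat_interleave(N, dim=0): e1, e1, e1, e2, e2, e2, e3, e3, e3
in_chunks = [row for row in Wh for _ in range(N)]
# Same row order as Wh.repeat(N, 1): e1, e2, e3, e1, e2, e3, e1, e2, e3
alternating = [row for _ in range(N) for row in Wh]

# torch.cat([...], dim=1): join row k of both matrices -> shape (N*N, 2*F_out)
all_combinations = [a + b for a, b in zip(in_chunks, alternating)]

# .view(N, N, 2*F_out): consecutive blocks of N rows become a_input[0], a_input[1], ...
a_input = [all_combinations[i * N:(i + 1) * N] for i in range(N)]

for i in range(N):
    for j in range(N):
        assert a_input[i][j] == Wh[i] + Wh[j]  # list + list is concatenation (||)
print(a_input[1][2])  # [3.0, 4.0, 5.0, 6.0]

# Size of a_input for Cora, per head: 2708 x 2708 x 16 float32 values, 4 bytes each
cora_bytes = 2708 * 2708 * 16 * 4
print(cora_bytes / 1e6, "MB")  # about 469 MB
```

The memory figure shows that this dense pairing grows as $N^2$. It works for Cora but becomes expensive on larger graphs.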
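6. e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2)) computes the raw attention scores. a_input has shape $[2708, 2708, 16]$ and self.a has shape $[16, 1]$, so their product has shape $[2708, 2708, 1]$. Entry $(i, j)$ is the unnormalized attention score of node $i$ toward node $j$. squeeze(2) removes that last dimension, leaving shape $[2708, 2708]$. LeakyReLU with negative slope alpha = 0.2 is then applied elementwise.
7. zero_vec = -9e15*torch.ones_like(e): e contains scores between each node and all nodes, but only the coefficients of connected nodes are needed. So a matrix zero_vec with the same shape as e is created, filled with -9e15 as a stand-in for negative infinity (despite its name, it is not zero).
8. attention = torch.where(adj > 0, e, zero_vec) keeps e[i][j] wherever adj[i][j] > 0 and replaces every other entry (no edge) with -9e15. The result has the same shape as the adjacency matrix, and each remaining value is the attention score of a node toward one of its neighbors.
9. attention = F.softmax(attention, dim=1) applies softmax to each row of the attention matrix. Each row then sums to 1 and gives the normalized weights a node assigns to its neighbors. The -9e15 entries become numerically 0.
10. attention = F.dropout(attention, self.dropout, training=self.training) applies dropout to the attention matrix.
11. h_prime = torch.matmul(attention, Wh) multiplies the attention matrix by Wh, the features already transformed by this head's weight W. Each row of h_prime is therefore the attention-weighted sum of the transformed features of that node's neighbors.
12. Because concat=True, the ELU activation is applied to h_prime. h_prime has shape $2708 \times 8$.
13. This completes one attention head. There are 8 heads in total.
14. The 8 resulting h_prime tensors are concatenated along dim=1, giving x with shape [2708, 64].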
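The plain-Python sketch below condenses steps 6 to 14 into a single head. It runs on a made-up 3-node path graph with self-loops and follows the same order as forward: scores via $a^\top [Wh_i \,\|\, Wh_j]$, LeakyReLU, masking non-edges with -9e15, row-wise softmax, aggregation, and ELU. Dropout is omitted, as in evaluation mode. All values and helper names (`gat_head`, `leaky_relu`, `elu`, `softmax`) are illustrative assumptions, not the pyGAT API.

The sketch also checks a standard identity. If a is split into halves $a_1$ and $a_2$, then $a^\top [Wh_i \,\|\, Wh_j] = a_1^\top Wh_i + a_2^\top Wh_j$. The same scores can therefore be computed without building the N × N × 2F tensor a_input.

```python
import math

ALPHA = 0.2      # LeakyReLU negative slope, as in the model
NEG_INF = -9e15  # same stand-in for -inf as zero_vec


def leaky_relu(x, alpha=ALPHA):
    return x if x > 0 else alpha * x


def elu(x):
    return x if x > 0 else math.exp(x) - 1.0


def dot(u, v):
    return sum(p * q for p, q in zip(u, v))


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]  # exp(-9e15 - m) underflows to 0.0
    s = sum(exps)
    return [v / s for v in exps]


def gat_head(Wh, adj, a):
    """One attention head with concat=True, without dropout."""
    n = len(Wh)
    # Steps 6-8: raw scores for every pair, then mask out non-edges
    e = [[leaky_relu(dot(a, Wh[i] + Wh[j])) for j in range(n)] for i in range(n)]
    masked = [[e[i][j] if adj[i][j] > 0 else NEG_INF for j in range(n)] for i in range(n)]
    # Step 9: row-wise softmax
    attention = [softmax(row) for row in masked]
    # Step 11: h_prime = attention @ Wh
    f = len(Wh[0])
    h_prime = [[sum(attention[i][k] * Wh[k][c] for k in range(n)) for c in range(f)] for i in range(n)]
    # Step 12: ELU
    return attention, [[elu(v) for v in row] for row in h_prime]


# Made-up example: path graph 0 - 1 - 2 with self-loops, 2 features per node
adj = [[1, 1, 0],
       [1, 1, 1],
       [0, 1, 1]]
Wh = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
a = [0.5, -0.5, 1.0, 0.25]  # length 2 * out_features

attention, out = gat_head(Wh, adj, a)
assert all(abs(sum(row) - 1.0) < 1e-12 for row in attention)  # each row sums to 1
assert attention[0][2] == 0.0 and attention[2][0] == 0.0       # no edge -> zero weight

# Same scores from the split form a1 . Wh_i + a2 . Wh_j
a1, a2 = a[:2], a[2:]
for i in range(3):
    for j in range(3):
        assert abs(dot(a, Wh[i] + Wh[j]) - (dot(a1, Wh[i]) + dot(a2, Wh[j]))) < 1e-12

# Step 14: 8 heads with 8 output features each, concatenated along dim=1
heads, nhid = 8, 8
print(heads * nhid)  # 64, matching x.shape == [2708, 64]
```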
Summary