toch_geometric 笔记:message passing GCNConv
?1?message passing介紹
? ? ? ? 將卷積算子推廣到不規(guī)則域通常表示為一個(gè)鄰域聚合(neighborhood aggregation)或消息傳遞(message passing?)方案
? ? ? ? 給定第(k-1)層點(diǎn)的特征,以及可能有的點(diǎn)與點(diǎn)之間邊的特征,依靠信息傳遞的GNN可以被描述成:
?
?其中表示一個(gè)可微分的可微,置換不變的函數(shù)(比如sum、mean或者max),γ和Φ表示可微分方程(比如MLP)
2 message passing 類??
????????PyG提供了message?passing基類,它通過自動(dòng)處理消息傳播來幫助創(chuàng)建這類消息傳遞圖神經(jīng)網(wǎng)絡(luò)。
? ? ? ? 使用者只需要定義γ(update函數(shù))和Φ(message函數(shù)),以及聚合方式aggr(即)【aggr="add",?aggr="mean"?or?aggr="max"】即可
2.1?MessagePassing
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)?定義了聚合方式(這里是’add‘)
信息傳遞的流方向("source_to_target"?【默認(rèn)】or?"target_to_source")
node_dim表示了沿著哪個(gè)軸進(jìn)行傳遞
2.2?MessagePassing.propagate
MessagePassing.propagate(edge_index, size=None, **kwargs)????????開始傳播消息的初始調(diào)用。
????????獲取邊索引(edge index)和所有額外的數(shù)據(jù),這些數(shù)據(jù)是構(gòu)造消息和更新節(jié)點(diǎn)嵌入所需要的。
? ? ? ? propagate()不僅可以在[N,N]的鄰接方陣中傳遞消息,還可以在非方陣中傳遞消息,(比如二部圖[N,M],此時(shí)設(shè)置size=(N,M)作為額外的形參)
? ? ? ? 如果size參數(shù)設(shè)置為None,那么矩陣默認(rèn)是一個(gè)方陣。
? ? ? ? 對于二部圖[N,M]來說,它有兩組互相獨(dú)立的點(diǎn)集,我們還需要設(shè)置x=(x_N,x_M)
2.3?MessagePassing.message(...)
? ? ? ? 類似于Φ。將信息傳遞到節(jié)點(diǎn)i上。?如果flow="source_to_target",那么是找所有(j,i)∈E;如果flow="target_to_source",那么找所有(i,j)屬于E。
????????可以接受最初傳遞給propagate()的任何參數(shù)。
????????此外,傳遞給propagate()的張量可以通過在變量名后面附加_i或_j,映射到各自的節(jié)點(diǎn)。例如,x_i(表示中心節(jié)點(diǎn))、?x_j(表示鄰居節(jié)點(diǎn))。
????????注意,我們通常將i稱為匯聚信息的中心節(jié)點(diǎn),將j稱為相鄰節(jié)點(diǎn),因?yàn)檫@是最常見的表示法。
2.4?MessagePassing.update(aggr_out,?...)
? ? ? ? 類比γ,對每個(gè)點(diǎn)i∈ V,更新它的node embedding
? ? ? ? 第一個(gè)參數(shù)是聚合輸出,同時(shí)將所有傳遞給propagate()的參數(shù)作為后續(xù)參數(shù)
3 舉例: GCN
3.1 GCN回顧
GCN層可以表示為:
?????????k-1層的鄰居節(jié)點(diǎn)先通過權(quán)重矩陣Θ加權(quán),然后用中心節(jié)點(diǎn)和這個(gè)鄰居節(jié)點(diǎn)的度來進(jìn)行歸一化,最后求和聚合 。
3.2 message passing 實(shí)現(xiàn)過程
? ? ? ? 這個(gè)方程可以劃分成以下幾個(gè)步驟
????????步驟1~3在message passing開始前就已經(jīng)計(jì)算完畢了;步驟4,5則可以用MessagePassing操作來進(jìn)行處理 。
3.3 代碼解析
import torch from torch_geometric.nn import MessagePassing from torch_geometric.utils import add_self_loops, degreeclass GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super().__init__(aggr='add') # "Add" aggregation (Step 5).#GCN類從MessagePssing中繼承得到的聚合方式:“add”self.lin = torch.nn.Linear(in_channels, out_channels)def forward(self, x, edge_index):# x has shape [N, in_channels] ——N個(gè)點(diǎn),每個(gè)點(diǎn)in_channels維屬性# edge_index has shape [2, E]——E條邊,每條邊有出邊和入邊# Step 1: Add self-loops to the adjacency matrix.edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))#添加自環(huán)# Step 2: Linearly transform node feature matrix.x = self.lin(x)#對X進(jìn)行線性變化# Step 3: Compute normalization.row, col = edge_index#出邊和入邊deg = degree(col, x.size(0), dtype=x.dtype)#各個(gè)點(diǎn)的入度(無向圖,所以入讀和出度相同)deg_inv_sqrt = deg.pow(-0.5)deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]#1/sqrt(di) *1/sqrt(dj)# Step 4-5: Start propagating messages.return self.propagate(edge_index, x=x, norm=norm)#進(jìn)行propagate#propagate的內(nèi)部會(huì)調(diào)用message(),aggregate()和update()#作為消息傳播的附加參數(shù),我們傳遞節(jié)點(diǎn)嵌入x和標(biāo)準(zhǔn)化系數(shù)norm。 def message(self, x_j, norm):# x_j has shape [E, out_channels]#我們需要對相鄰節(jié)點(diǎn)特征x_j進(jìn)行norm標(biāo)準(zhǔn)化#這里x_j為一個(gè)張量,其中包含每條邊的源節(jié)點(diǎn)特征,即每個(gè)節(jié)點(diǎn)的鄰居。# Step 4: Normalize node features.return norm.view(-1, 1) * x_j#1/sqrt(di) *1/sqrt(dj) *X_j? 之后,我們就可以用這種方法輕松調(diào)用了:
conv = GCNConv(16, 32) x = conv(x, edge_index)總結(jié)
以上是生活随笔為你收集整理的toch_geometric 笔记:message passing GCNConv的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: torch_geometric 笔记:
- 下一篇: torch_geometric 笔记:T