GraphSAGE节点分类
GraphSAGE 節點分類
簡介
GCN(Graph Convolutional Network)的出現帶動了將神經網絡技術用于圖數據的學習任務中去,并產生了大量的方法,這類方法我們統稱為圖神經網絡(Graph Neural Networks,GNN)。我們知道,圖卷積可以從譜域和空域兩個角度看待(盡管后來逐漸深入的研究表明,所謂的譜域圖卷積其實就是特殊的空域圖卷積而已),從空域來看,GCN 的本質就是一個迭代式地聚合鄰居的過程,這個思路啟發了一大類模型對于這種聚合操作的重新設計,比如比較有名的 GraphSAGE、GAT、R-GCN,這些以空域視角出發的圖網絡算法,通常被叫做空域圖卷積。本文在本系列上一篇文章GCN 節點分類的基礎上,使用 Pytorch 實現 GraphSAGE 對 Cora 數據集進行節點分類。
GraphSAGE 算法簡述
GraphSAGE 其實在兩個方面對 GCN 做了改動,一方面是通過采樣鄰居的策略將 GCN 由全圖(full batch)的訓練方式改造成以節點為中心的小批量(mini batch)的訓練方式,這使得大規模圖數據的分布式訓練成為了可能;另一方面,GraphSAGE 對聚合鄰居的操作進行了拓展,提出了替換 GCN 操作的新的方式。
采樣鄰居
GCN 的訓練是全圖形式的,就是說一輪迭代,所有節點的樣本的損失只會貢獻一次梯度,無法做到深度神經網絡中常用的小批量更新,從梯度更新的次數來看這是很低效的。這還不是重點,事實上,實際業務中,圖的規模是巨大的,顯存或者內存很難容納下模型和整個圖,因此采用小批量的訓練方法是必要的。GraphSAGE 從聚合鄰居的操作出發,對鄰居進行隨機采樣來控制實際運算時節點kkk階子圖的數據規模,在此基礎上對采樣的子圖進行隨機組合來完成小批量訓練。
GCN 中,節點在第k+1k+1k+1層的特征只與其鄰居在kkk層的特征有關,這種局部性質導致節點在第kkk層的特征只與自己的kkk階子圖有關。雖然這樣說只需要考慮節點的kkk階子圖就可以完成對節點高層特征的計算,但是對于一個大規模圖數據而言,直接遷移此思路仍然存在一些問題:
上述的情況下,遍歷子圖的時間代價、模型訓練的計算和存儲代價都會難以把控。因此,GraphSAGE 使用了采樣鄰居的操作來控制子圖發散時的增長率。它的具體操作為:設每個節點在第kkk層的鄰居采樣倍率為SkS_kSk?(這是一個超參數),即每個節點采樣的一階鄰居不超過SkS_kSk?,那么對于任意一個中心節點的表達計算,所設計的總節點數將在O(∏k=1Ksk)O\left(\prod_{k=1}^{K} s_{k}\right)O(∏k=1K?sk?)這個級別。舉個例子,對一個兩層模型來說,如果S1=3S_1=3S1?=3,S2=2S_2=2S2?=2則總節點數不會超過1+3+3×2=101+3+3\times2=101+3+3×2=10個。這里對節點采樣,GraphSAGE 選擇了均勻分布,其實工程上會采用其他形式的分布。
通過采樣鄰居的策略,GraphSAGE 控制子圖節點的規模始終維持在階乘級別以下,這也給模型層數的增加提供了可能性。
聚合鄰居
GraphSAGE 研究了聚合鄰居操作所需的性質,提出了幾種新的聚合操作算子(aggregator),需滿足如下條件:
當然,從模型優化的角度看,這種聚合操作還必須可導。只要滿足上述性質,聚合操作就能對任意輸入的節點集合做到自適應。比較簡單的算子有平均/加和聚合算子、LSTM 聚合算子、池化聚合算子等,這里就不展開了,詳細可以參考原論文3.3 節。
GraphSAGE 算法過程
在上面兩個機制的基礎上,最后來看看 GraphSAGE 如何實現訓練的。
輸入:圖G(V,E)\mathcal{G}(\mathcal{V}, \mathcal{E})G(V,E);輸入特征{xv,?v∈B}\left\{\mathbf{x}_{v}, \forall v \in \mathcal{B}\right\}{xv?,?v∈B};層數KKK;權重矩陣Wk,?k∈{1,…,K}\mathbf{W}^{k}, \forall k \in\{1, \ldots, K\}Wk,?k∈{1,…,K};非線性函數σ\sigmaσ;聚合操作 AGGREGATE k,?k∈{1,…,K}_{k}, \forall k \in\{1, \ldots, K\}k?,?k∈{1,…,K};鄰居采樣函數Nk:v→2V,?k∈{1,…,K}\mathcal{N}_{k}: v \rightarrow 2^{\mathcal{V}}, \forall k \in\{1, \ldots, K\}Nk?:v→2V,?k∈{1,…,K}。
輸出:所有節點的向量表示zv\mathbf{z}_{v}zv?, v∈Bv \in \mathcal{B}v∈B。
小批量訓練過程如下:
上述算法的基本思路為先將小批集合B\mathcal{B}B內的中心節點聚合操作要涉及到的kkk階子圖一次性遍歷出來,然后在這些節點上進行KKK次聚合操作的迭代式計算。上述圖中的 1-7 行就是描述遍歷操作,可以簡單理解這個過程:要想得到某個中心節點第kkk層的特征,就需要采樣其在第k?1k-1k?1層的鄰居,然后對k?1k-1k?1層每個節點采樣其第k?2k-2k?2層的鄰居,以此類推,直到采樣完第一層所有的鄰居為止。注意,每層的采樣函數可以單獨設置。
上述算法圖的 9-15 行是第二步,聚合操作,其核心為 11-13 行的三個公式。第 11 行的式子是調用聚合操作完成對每個節點鄰居特征的整合輸出,第 12 行是將聚合后的鄰居特征與中心節點上一層的特征進行拼接,然后送到一個單層網絡里得到中心節點的特征向量,第 13 行對節點的特征向量進行歸一化。對這三行操作迭代KKK次就完成了對B\mathcal{B}B內所有中心節點特征向量的提取。
GraphSAGE 的算法過程完全沒有拉普拉斯矩陣的參與,每個節點的特征學習過程僅僅只與其kkk階鄰居相關,而不需要全圖對的結構西南西,這樣的方法適合做歸納學習(Inductive Learning),這也就是 GraphSAGE 論文題目 Inductive Representation Learning on Large Graphs 的由來。這里我就不多闡述歸納學習和轉導學習(Transductive Learning)的理論,需要知道的是,對 GraphSAGE 而言,新出現的節點數據,只需要遍歷得到kkk階子圖,就可以代入模型進行預測,這種特性使得 GraphSAGE 潛力巨大。
總的來說,GraphSAEG 對空域視角下的 GCN 作了一次解構,提出幾種鄰居聚合算子,同時通過采樣鄰居,大大改進了算法的性能,關于其更詳細的內容推薦閱讀原論文。
GraphSAGE 節點分類
本節使用 Pytorch 實現 GraphSAGE 對 Cora 數據集進行節點分類,通過代碼進一步理解 GraphSAGE。GraphSAGE 包括鄰居采樣和鄰居聚合兩個方面。
首先來看鄰居采樣,通過下面的兩個函數實現了一階和多階采樣,為了高效,節點和鄰居的關系維護一個表即可。
import numpy as npdef sampling(src_nodes, sample_num, neighbor_table):"""根據源節點一階采樣指定數量的鄰居,有放回:param src_nodes::param sample_num::param neighbor_table::return:"""results = []for sid in src_nodes:# 從節點的鄰居中進行有放回地進行采樣neighbor_nodes = neighbor_table.getrow(sid).nonzero()res = np.random.choice(np.array(neighbor_nodes).flatten(), size=sample_num)results.append(res)return np.asarray(results).flatten()def multihop_sampling(src_nodes, sample_nums, neighbor_table):"""根據源節點進行多階采樣:param src_nodes::param sample_nums::param neighbor_table::return:"""sampling_result = [src_nodes]for k, hopk_num in enumerate(sample_nums):hopk_result = sampling(sampling_result[k], hopk_num, neighbor_table)sampling_result.append(hopk_result)return sampling_result這樣的阿斗的結果是節點的 ID,還需要根據 ID 查詢節點的特征以進行聚合操作更新特征。
接著我們來看鄰居聚合,定義一個 Pytorch module 來完成聚合過程,輸入特征先是經過一個線性變換得到隱層特征,從而可以在第一個維度進行聚合操作,預定義了求和、均值、最大值等算子。
class NeighborAggregator(nn.Module):def __init__(self, input_dim, output_dim,use_bias=False, aggr_method="mean"):"""聚合節點鄰居:param input_dim: 輸入特征的維度:param output_dim: 輸出特征的維度:param use_bias: 是否使用偏置:param aggr_method: 鄰居聚合算子形式"""super(NeighborAggregator, self).__init__()self.input_dim = input_dimself.output_dim = output_dimself.use_bias = use_biasself.aggr_method = aggr_methodself.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))if self.use_bias:self.bias = nn.Parameter(torch.Tensor(self.output_dim))self.reset_parameters()def reset_parameters(self):init.kaiming_uniform_(self.weight)if self.use_bias:init.zeros_(self.bias)def forward(self, neighbor_feature):if self.aggr_method == "mean":aggr_neighbor = neighbor_feature.mean(dim=1)elif self.aggr_method == "sum":aggr_neighbor = neighbor_feature.sum(dim=1)elif self.aggr_method == "max":aggr_neighbor = neighbor_feature.max(dim=1)else:raise ValueError("Unknown aggr type, expected sum, max, or mean, but got {}".format(self.aggr_method))neighbor_hidden = torch.matmul(aggr_neighbor, self.weight)if self.use_bias:neighbor_hidden += self.biasreturn neighbor_hiddendef extra_repr(self):return 'in_features={}, out_features={}, aggr_method={}'.format(self.input_dim, self.output_dim, self.aggr_method)基于鄰居聚合的結果對中心節點的特征進行更新。更新的方式是將鄰居節點聚合的特征與經過線性變換的中心特征加和或者級聯,再經過一個激活函數得到更新后的特征,依次我們就可以實現新的 GCN 層。繼而,可以堆疊 SAGEGCN 來構建模型,實現訓練。
class GraphSAGE(nn.Module):def __init__(self, input_dim, hidden_dim,num_neighbors_list):super(GraphSAGE, self).__init__()self.input_dim = input_dimself.hidden_dim = hidden_dimself.num_neighbors_list = num_neighbors_listself.num_layers = len(num_neighbors_list)self.gcn = nn.ModuleList()self.gcn.append(SAGEGCN(input_dim, hidden_dim[0]))for index in range(0, len(hidden_dim) - 2):self.gcn.append(SAGEGCN(hidden_dim[index], hidden_dim[index + 1]))self.gcn.append(SAGEGCN(hidden_dim[-2], hidden_dim[-1], activation=None))def forward(self, node_features_list):hidden = node_features_listfor l in range(self.num_layers):next_hidden = []gcn = self.gcn[l]for hop in range(self.num_layers - l):src_node_features = hidden[hop]src_node_num = len(src_node_features)neighbor_node_features = hidden[hop + 1] \.view((src_node_num, self.num_neighbors_list[hop], -1))h = gcn(src_node_features, neighbor_node_features)next_hidden.append(h)hidden = next_hiddenreturn hidden[0]def extra_repr(self):return 'in_features={}, num_neighbors_list={}'.format(self.input_dim, self.num_neighbors_list)下圖是訓練過程可視化的結果,可以看到,GraphSAGE以mini batch的方式訓練,并在很少的輪次后獲得了和GCN相當的精度。
補充說明
本文關于 GraphSAGE 的理論以及代碼部分參考《深入淺出圖神經網絡》以及 GraphSAGE 論文原文。本文涉及到的代碼開源于Github,歡迎 star和fork。
總結
以上是生活随笔為你收集整理的GraphSAGE节点分类的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: GCN节点分类
- 下一篇: DCN RepPoints解读