torch_geometric笔记:数据集 ENZYMES Minibatches
?????????Pytorch Geometric中包含大量的常見基準數據集。在初始化數據集的時候,框架會自動下載數據集的原始文件,并將其處理為Data對象。例如要下載ENZYMES數據集(由600個graph劃分為6個類別)
1 下載數據集
from torch_geometric.datasets import TUDatasetdataset = TUDataset(root='', name='ENZYMES')dataset #ENZYMES(600)type(dataset) #torch_geometric.datasets.tu_dataset.TUDatasetlen(dataset) #600 #說明600張圖dataset.num_classes #6 #圖一共有6各不同的類dataset.num_node_features #3 每一個節點有三個特征data = dataset[0] data #Data(edge_index=[2, 168], x=[37, 3], y=[1]) #第一張圖有168條有向邊,37個節點,每個節點3個特征,整張圖有一個類別data.is_undirected() #True2?Mini-batches
????????神經網絡通常以batch的方式進行訓練,geometric在mini-batch實現了并行化,這種組合允許在一個batch中使用不同數量的邊和節點。
????????在torch_geometric.data.DataLoader中,已經包含了此過程。
????????這種mini-batch的操作本質上來說是將一個batch的graph看成是一個大的graph,由此,無論batch size是多少,其將所有的操作都統一在一個大圖上進行操作。
from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoaderdataset = TUDataset(root='', name='ENZYMES', use_node_attr=True) loader = DataLoader(dataset, batch_size=32, shuffle=True)for batch in loader:print(batch,batch.num_graphs)''' Batch(edge_index=[2, 3890], x=[1075, 21], y=[32], batch=[1075], ptr=[33]) 32 Batch(edge_index=[2, 4284], x=[1157, 21], y=[32], batch=[1157], ptr=[33]) 32 Batch(edge_index=[2, 4098], x=[1086, 21], y=[32], batch=[1086], ptr=[33]) 32 Batch(edge_index=[2, 3668], x=[916, 21], y=[32], batch=[916], ptr=[33]) 32 Batch(edge_index=[2, 4062], x=[1074, 21], y=[32], batch=[1074], ptr=[33]) 32 Batch(edge_index=[2, 4086], x=[1096, 21], y=[32], batch=[1096], ptr=[33]) 32 Batch(edge_index=[2, 3954], x=[1005, 21], y=[32], batch=[1005], ptr=[33]) 32 Batch(edge_index=[2, 4170], x=[1064, 21], y=[32], batch=[1064], ptr=[33]) 32 Batch(edge_index=[2, 4258], x=[1149, 21], y=[32], batch=[1149], ptr=[33]) 32 Batch(edge_index=[2, 3836], x=[997, 21], y=[32], batch=[997], ptr=[33]) 32 Batch(edge_index=[2, 3886], x=[1016, 21], y=[32], batch=[1016], ptr=[33]) 32 Batch(edge_index=[2, 4066], x=[1042, 21], y=[32], batch=[1042], ptr=[33]) 32 Batch(edge_index=[2, 3946], x=[1046, 21], y=[32], batch=[1046], ptr=[33]) 32 Batch(edge_index=[2, 3656], x=[927, 21], y=[32], batch=[927], ptr=[33]) 32 Batch(edge_index=[2, 4110], x=[1034, 21], y=[32], batch=[1034], ptr=[33]) 32 Batch(edge_index=[2, 3824], x=[1002, 21], y=[32], batch=[1002], ptr=[33]) 32 Batch(edge_index=[2, 4178], x=[1116, 21], y=[32], batch=[1116], ptr=[33]) 32 Batch(edge_index=[2, 3736], x=[974, 21], y=[32], batch=[974], ptr=[33]) 32 Batch(edge_index=[2, 2856], x=[804, 21], y=[24], batch=[804], ptr=[25]) 24 '''以? Batch(edge_index=[2, 3890], x=[1075, 21], y=[32], batch=[1075], ptr=[33])? 為例:
- edge_index=[2, 3890]——這個batch一共3890條邊
- x=[1075, 21]——整個batch的節點特征矩陣,這個batch一共2075個點,至于這個21,我不太明白,是因為不同的圖有不同的特征,所以拼起來一共21個不同的特征嗎?歡迎大家在評論區指正!
- y=[32]——32個圖,32維特征
- batch=[1075]——batch是一個列向量,它將每個節點映射到該batch中的對應的graph:
? ? ? ?
????????至于這個ptr,查了很多資料,都沒一個說法。
????????于是我自己做了一些嘗試:感覺可能是這個意思(歡迎指正哈):就是這個batch目前累計看到的圖的節點數量
? ? ? ? 因為實驗是后來補的,所以了不同的圖
from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoaderdataset = TUDataset(root='', name='ENZYMES', use_node_attr=True) loader = DataLoader(dataset, batch_size=32, shuffle=True)for batch in loader:print(batch,batch.num_graphs)break #Batch(edge_index=[2, 3822], x=[980, 21], y=[32], batch=[980], ptr=[33]) 32batch['ptr'] ''' tensor([ 0, 41, 66, 78, 122, 151, 193, 229, 261, 284, 350, 397, 429, 453,493, 534, 576, 588, 605, 634, 644, 660, 693, 723, 763, 811, 834, 847,887, 899, 940, 956, 980]) '''sum=[0] for i in range(32):sum.append(sum[-1]+int(batch[i]['x'].shape[0])) print(sum)''' [0, 41, 66, 78, 122, 151, 193, 229, 261, 284, 350, 397, 429, 453, 493, 534, 576, 588, 605, 634, 644, 660, 693, 723, 763, 811, 834, 847, 887, 899, 940, 956, 980] '''2.1 自己的圖列表 &DataLoader
不難發現,這種下載的數據集,可以看成是圖的集合
那么如果我門自己設計了一些圖,集合成一個列表,我們可以直接用這個列表構造DataLoader(注:這里的DataLoader是torch_geometric.loader的DataLoader)
?
總結
以上是生活随笔為你收集整理的torch_geometric笔记:数据集 ENZYMES Minibatches的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: pytorch 笔记:torch_geo
- 下一篇: torch_geometric 笔记: