PYG教程【三】对Cora数据集进行半监督节点分类
Cora數(shù)據(jù)集
PyG包含有大量的基準(zhǔn)數(shù)據(jù)集。初始化數(shù)據(jù)集非常簡(jiǎn)單,數(shù)據(jù)集初始化會(huì)自動(dòng)下載原始數(shù)據(jù)文件,并且會(huì)將它們處理成Data格式。
如下圖所示,Cora數(shù)據(jù)集中只有一個(gè)圖,該圖包含2708個(gè)節(jié)點(diǎn),10556條邊,節(jié)點(diǎn)類別數(shù)為7,特征維度為1433。并且默認(rèn)已經(jīng)對(duì)數(shù)據(jù)集進(jìn)行了劃分,分為了訓(xùn)練集、驗(yàn)證集和測(cè)試集。
然后看看節(jié)點(diǎn)特征和標(biāo)簽。x為節(jié)點(diǎn)特征矩陣,維度為2708*1433。y為節(jié)點(diǎn)標(biāo)簽向量,維度為2708,類別為7。
用GCN進(jìn)行半監(jiān)督節(jié)點(diǎn)分類
接下來(lái)就可以構(gòu)建一個(gè)簡(jiǎn)單的GCN模型,在Cora數(shù)據(jù)集上進(jìn)行半監(jiān)督節(jié)點(diǎn)分類。
下面的GCN模型包含兩個(gè)圖卷積層。第一層輸入維度為1433(節(jié)點(diǎn)特征維度),輸出為16(與第一層輸出一致),后面接上一個(gè)relu激活函數(shù),以及dropout操作。第二層輸入維度為16,輸出為7(節(jié)點(diǎn)標(biāo)簽數(shù)量),后接log_softmax函數(shù)進(jìn)行分類。
模型構(gòu)建完成后,指定訓(xùn)練設(shè)備為GPU(沒(méi)有的話就用CPU),注意這里默認(rèn)使用的是0號(hào)cuda。如果cuda:0被占用了的話會(huì)報(bào)錯(cuò),需要指定其他號(hào)碼的cuda才能運(yùn)行。然后,分別將GCN模型以及Cora圖數(shù)據(jù)送入指定的設(shè)備。
優(yōu)化器選擇Adam,學(xué)習(xí)率設(shè)置為0.01,權(quán)重衰減設(shè)置為5e-4。這些都配置好以后就可以訓(xùn)練模型了,epoch設(shè)為200,每個(gè)epoch后清除上次的梯度信息,然后用nll_loss計(jì)算出訓(xùn)練集上的損失,調(diào)用backward函數(shù)計(jì)算出梯度后傳回給Adam優(yōu)化器進(jìn)行參數(shù)更新。
最后在測(cè)試集上評(píng)估模型,計(jì)算分類正確率accuracy并顯示。
至此,就完成了Cora數(shù)據(jù)集上的節(jié)點(diǎn)分類任務(wù)了。
總結(jié)
以上是生活随笔為你收集整理的PYG教程【三】对Cora数据集进行半监督节点分类的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 人工智能技术应用学什么
- 下一篇: PYG教程【四】Node2Vec节点分类