学习报告:基于原型网络的小样本学习《Prototypical Networks for Few-shot Learning》
學(xué)習(xí)報告:基于原型網(wǎng)絡(luò)的小樣本學(xué)習(xí)《Prototypical Networks for Few-shot Learning》
- 一、概述
- 二、方法解析
- 三、實驗
- 3.1 說明
- 3.2 Omniglot分類
- 3.3 miniImageNet分類
- 四、總結(jié)分析
本篇學(xué)習(xí)報告基于論文《Prototypical Networks for Few-shot Learning》,該論文的主要貢獻(xiàn)有兩點:(1)對圖像領(lǐng)域的Few-Shot/Zero-Shot(小樣本/零樣本)任務(wù),應(yīng)用設(shè)計簡單的原型網(wǎng)絡(luò)方法(見第二部分),在通用數(shù)據(jù)集上達(dá)到了較好的實驗效果(見第三部分);(2)對原型網(wǎng)絡(luò)本身進(jìn)行了較為深入的分析,且分析了距離度量方式的選擇對任務(wù)效果的影響(見圖3)。
原文鏈接及開源代碼已置于文末。
一、概述
在小樣本分類問題中,最需要解決的一個問題是數(shù)據(jù)的過擬合問題。由于訓(xùn)練數(shù)據(jù)過少,一般的分類算法會表現(xiàn)出過擬合的現(xiàn)象,從而導(dǎo)致分類結(jié)果與實際結(jié)果有較大的誤差。為了減少因數(shù)據(jù)量過少而導(dǎo)致的過擬合的影響,可以使用基于度量的元學(xué)習(xí)方法,該論文所提出的原型網(wǎng)絡(luò)便屬于這種方法。
該論文為解決小樣本分類問題提出了原型網(wǎng)絡(luò)。在訓(xùn)練集中,對于每一種出現(xiàn)的類別,只給出少量樣本,但分類器能夠很好的泛化到其他沒有出現(xiàn)于訓(xùn)練集中的新類別。原型網(wǎng)絡(luò)會學(xué)習(xí)一個度量空間,在該空間中,可以通過計算與每個類的對應(yīng)原型表示的距離來進(jìn)行分類,距離哪個類的原型表示最近,則被判斷為哪個類。與最近的小樣本學(xué)習(xí)方法相比,該方法反映了一種更簡單的歸納偏差,有利于在這種有限的數(shù)據(jù)范圍內(nèi)使用,并取得優(yōu)異的效果。論文表明一些簡單的設(shè)計決策比最近涉及復(fù)雜體系結(jié)構(gòu)選擇和元學(xué)習(xí)的方法可以產(chǎn)生較好的改進(jìn)效果。
介紹兩類常見的Few-Shot方法:
匹配網(wǎng)絡(luò)(Matching Network):
可以理解為在embedding空間中的加權(quán)最近鄰分類器。模型在訓(xùn)練過程中通過對類標(biāo)簽和樣本的二次采樣來模仿Few-Shot任務(wù)的測試場景,學(xué)習(xí)一個匹配網(wǎng)絡(luò)。該網(wǎng)絡(luò)只在訓(xùn)練集中的關(guān)系基礎(chǔ)上訓(xùn)練,并且直接應(yīng)用于測試集中的關(guān)系。原型網(wǎng)絡(luò)也屬于一種匹配網(wǎng)絡(luò)。實驗和總結(jié)中將對原型網(wǎng)絡(luò)和匹配網(wǎng)絡(luò)的不同之處和分類效果進(jìn)行比較。
Optimization-based meta-learning:
這種方法在訓(xùn)練的過程中的目標(biāo)是學(xué)習(xí)如何通過少量樣本更好的擬合數(shù)據(jù),因此該類方法會針對測試數(shù)據(jù)集對網(wǎng)絡(luò)進(jìn)行調(diào)整。例如,在訓(xùn)練過程中,利用LSTM的網(wǎng)絡(luò)結(jié)構(gòu)學(xué)習(xí)每個訓(xùn)練step所需要的學(xué)習(xí)率。
二、方法解析
在該論文所提出的原型網(wǎng)絡(luò)方法中,需要將樣本投影到一個度量空間,且在這個空間中同類樣本距離較近,異類樣本的距離較遠(yuǎn)。下圖為這個投影空間的示意圖,假如在這個投影空間中,存在三個類別的樣本,且相同類別的樣本間距離較近。為了給一個未標(biāo)注樣本x進(jìn)行標(biāo)注,則將樣本x投影至這個空間并計算x與各個類別的原型距離,離得近的就認(rèn)為x屬于哪個類別。
圖1 投影空間示意圖那么,現(xiàn)在有幾個問題:
1、怎么將這些樣本投影至一個空間且讓同類樣本間距離較近?
2、怎么說明一個類別所在的位置,從而能夠讓未標(biāo)記的樣本計算與類別的距離?
如何將樣本投影至一個空間且讓同類樣本間距離較近?論文中使用的是一個帶參數(shù)φ的嵌入函數(shù)fφ(x),這個函數(shù)可以理解為投影的過程,x表示樣本的特征向量,函數(shù)值表示投影到那個空間后的值,這個嵌入函數(shù)fφ(x)是一個神經(jīng)網(wǎng)絡(luò),參數(shù)φ是需要學(xué)習(xí)的,可以認(rèn)為參數(shù)φ決定了樣本間的位置,所以需要學(xué)習(xí)到一個較好的φ值,讓同類別樣本間距離較近。
此外,還需要考慮如何說明一個類別所在位置,論文中認(rèn)為一個類的位置由這個類所有樣本在投影空間里的平均值決定,類k的原型表示公式如下:
其中Sk表示類k,|Sk|表示類k中樣本的數(shù)量,(xi , yi)為樣本的特征向量和標(biāo)記,此公式實際上為一個求平均的過程。
得到每個類的原型后,就需要根據(jù)樣本與各個類的原型的距離,求一個樣本屬于一個類的概率。因為在訓(xùn)練時這個樣本是已標(biāo)記的,即我們已知類k的原型,已知一個屬于類k的樣本,求此樣本屬于類k的概率,因此我們的目標(biāo)函數(shù)就是求這個概率的最大值。
此公式所表示的意義是,對于樣本x,求它到每個類的距離,然后進(jìn)行歸一化操作得到概率,即x屬于類k的概率。其中d為距離函數(shù),在本篇論文中使用的是歐幾里得距離。在訓(xùn)練過程中,x的標(biāo)簽是已知的。論文中的目標(biāo)函數(shù)為:
一般通過隨機梯度下降方法來求它的最小值,從而收斂后學(xué)到一個好的φ值。可以認(rèn)為,訓(xùn)練結(jié)束后此投影函數(shù)可以將同類的樣本投影到一個相互距離較近的地方。
字符說明:
N:訓(xùn)練集中樣例的數(shù)量
K:訓(xùn)練集中類的數(shù)量
NC:每個Episode中類別的數(shù)量
NS:每個類中支持樣例的數(shù)量
NQ:每個類中查詢樣例的數(shù)量
以下Algorithm 1給出了計算訓(xùn)練集損失J(Φ)的偽代碼
計算過程:為Episode選擇類別 → 選擇支持集 → 選擇訓(xùn)練集 → 計算支持集的原型 → 初始化損失 → 更新?lián)p失
在測試過程中,使用與訓(xùn)練過程中相同的投影函數(shù)方法,求每個類的原型,根據(jù)一個未標(biāo)記的樣本x,求屬于每個類的概率,認(rèn)為概率值大的那個,即為x屬于的類別。
總結(jié)原型網(wǎng)絡(luò)的基本思想:基于集群,找到類的原型,找到合適距離度量方式進(jìn)行分類。
三、實驗
3.1 說明
實驗的數(shù)據(jù)分為支持集和查詢集:
支持集:即訓(xùn)練集,在該論文中由一些已標(biāo)記的樣本組成,比如有N個類,每個類中有M個樣本,則為N-way–M-shot。
查詢集:即測試集,在該論文中由一些已標(biāo)記的樣本和部分未標(biāo)記的樣本組成,后續(xù)實驗結(jié)果表明訓(xùn)練集的way大于測試集的話分類結(jié)果更好(我認(rèn)為這有助于提高模型的泛化性),而shot最好一致(我認(rèn)為是為了保持不同類別樣本的平衡性)。
3.2 Omniglot分類
Omniglot是一個1623個手寫字符分類的數(shù)據(jù)集。每一個字符類別只有20個樣本,不同樣本由不同的人繪制。
該論文使用原形網(wǎng)絡(luò)在Omniglot數(shù)據(jù)集上進(jìn)行實驗,使用歐幾里得距離作為距離度量,分別在1-shot和5-shot進(jìn)行實驗。下圖為某個子集的度量空間的可視化,其中黑色點代表每種類別的原形,紅色代表被錯誤分類的數(shù)據(jù),紅色箭頭的指向為真實的類別。
圖2 Omniglot數(shù)據(jù)集中某個子集的度量空間的t-SNE可視化圖訓(xùn)練episode的設(shè)置為60個類別和每個類別有5個query查詢點。實驗結(jié)果發(fā)現(xiàn)在訓(xùn)練和測試時保持相同的樣本數(shù)據(jù)量(即shot相同)和episode使用更多的類別(即way更大)會使得實驗效果更好。下表展示的是該論文所提出的方法與其他方法在Omniglot數(shù)據(jù)集上的結(jié)果對比。
表1 Omniglot數(shù)據(jù)集分類結(jié)果比較3.3 miniImageNet分類
minilmageNet數(shù)據(jù)集包含100個類別,每個類別中包含600個樣本數(shù)據(jù)。其中64個類別數(shù)據(jù)作為訓(xùn)練集,16個類別數(shù)據(jù)作為驗證集,20個類別數(shù)據(jù)作為測試集。
表2 miniImageNet數(shù)據(jù)集分類結(jié)果比較實驗分別對1-shot和5-shot的設(shè)置進(jìn)行訓(xùn)練episode為5-way和20-way的訓(xùn)練,實驗結(jié)果表明也訓(xùn)練episode中設(shè)置更多的類別,對實驗的結(jié)果有一定的增益效果,這是因為更大的way設(shè)置有助于網(wǎng)絡(luò)進(jìn)行更好的泛化,使得模型在度量空間做出更細(xì)粒度的決策。
還有個比較有意思的實驗結(jié)果:在N-way M-shot問題中的M=1,也就是one-shot的情況下,prototype network實際上等價于matching network;此外,無論是one-shot還是M-shot(M>1),歐氏距離(Euclid.)的效果都要比余弦距離(Cosine)的效果好(如下圖所示),因此本文使用的距離計算公式為歐氏距離。
四、總結(jié)分析
本論文提出的Prototypical Networks(P-net)思想與Matching Networks(M-net)十分相似,兩種網(wǎng)絡(luò)主要有以下不同點:1.使用了不同的距離度量方式,M-net中是余弦距離,P-net中使用的是屬于布雷格曼散度的歐幾里得距離。2.二者在few-shot的場景下不同,在one-shot時等價(one-shot時取得的原型就是支持集中的樣本,相當(dāng)于不用進(jìn)行平均處理)3.網(wǎng)絡(luò)結(jié)構(gòu)上,P-net將編碼層和分類層合一,參數(shù)更少,訓(xùn)練更加方便。論文的實驗部分中也在不同數(shù)據(jù)集上進(jìn)行了兩種網(wǎng)絡(luò)的效果比較,結(jié)果顯示P-net的效果要優(yōu)于M-net。本論文提出的原型網(wǎng)絡(luò)方法雖然結(jié)構(gòu)設(shè)計比較簡單,但是卻能達(dá)到很好的效果,這為我們在解決小樣本分類問題時提供了一種可行的解決思路。
論文地址:https://arxiv.org/pdf/1703.05175.pdf
源代碼:https://github.com/jakesnell/prototypical-networks
總結(jié)
以上是生活随笔為你收集整理的学习报告:基于原型网络的小样本学习《Prototypical Networks for Few-shot Learning》的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: java 线程 设计模式_Java多线程
- 下一篇: allegro如何编辑铜皮