置信学习:让样本中的“脏数据“原形毕露
在實(shí)際工作中,你是否遇到過這樣一個問題或痛點(diǎn):無論是通過哪種方式獲取的標(biāo)注數(shù)據(jù),數(shù)據(jù)標(biāo)注質(zhì)量可能不過關(guān),存在一些錯誤?亦或者是數(shù)據(jù)標(biāo)注的標(biāo)準(zhǔn)不統(tǒng)一、存在一些歧義?特別是badcase反饋回來,發(fā)現(xiàn)訓(xùn)練集標(biāo)注的居然和badcase一樣?如下圖所示,QuickDraw、MNIST和Amazon Reviews數(shù)據(jù)集中就存在錯誤標(biāo)注。
為了快速迭代,大家是不是常常直接人工去清洗這些“臟數(shù)據(jù)”?(筆者也經(jīng)常這么干~)。但數(shù)據(jù)規(guī)模上來了咋整?有沒有一種方法能夠自動找出哪些錯誤標(biāo)注的樣本呢?基于此,本文嘗試提供一種可能的解決方案——置信學(xué)習(xí)。
本文的組織架構(gòu)是:
01
置信學(xué)習(xí)的定義
那什么是置信學(xué)習(xí)呢?這個概念來自一篇由MIT和Google聯(lián)合提出的paper:《Confident Learning: Estimating Uncertainty in Dataset Labels[1] 》。論文提出的置信學(xué)習(xí)(confident learning,CL)是一種新興的、具有原則性的框架,以識別標(biāo)簽錯誤、表征標(biāo)簽噪聲并應(yīng)用于帶噪學(xué)習(xí)(noisy label learning)。
筆者注:筆者乍一聽置信學(xué)習(xí)挺陌生的,但回過頭來想想,好像干過類似的事情,比如:在某些場景下,對訓(xùn)練集通過交叉驗(yàn)證來找出一些可能存在錯誤標(biāo)注的樣本,然后交給人工去糾正。此外,神經(jīng)網(wǎng)絡(luò)的成功通常建立在大量、干凈的數(shù)據(jù)上,標(biāo)注錯誤過多必然會影響性能表現(xiàn),帶噪學(xué)習(xí)可是一個大的topic,有興趣可參考這些文獻(xiàn):github.com/subeeshvasu/。廢話不說,首先給出這種置信學(xué)習(xí)框架的優(yōu)勢:
-
最大的優(yōu)勢:可以用于發(fā)現(xiàn)標(biāo)注錯誤的樣本!
-
無需迭代,開源了相應(yīng)的python包,方便快速使用!在ImageNet中查找訓(xùn)練集的標(biāo)簽錯誤僅僅需要3分鐘!
-
可直接估計噪聲標(biāo)簽與真實(shí)標(biāo)簽的聯(lián)合分布,具有理論合理性。
-
不需要超參數(shù),只需使用交叉驗(yàn)證來獲得樣本外的預(yù)測概率。
-
不需要做隨機(jī)均勻的標(biāo)簽噪聲的假設(shè)(這種假設(shè)在實(shí)踐中通常不現(xiàn)實(shí))。
-
與模型無關(guān),可以使用任意模型,不像眾多帶噪學(xué)習(xí)與模型和訓(xùn)練過程強(qiáng)耦合。
02
置信學(xué)習(xí)開源工具:cleanlab
論文最令人驚喜的一點(diǎn)就是作者這個置信學(xué)習(xí)框架進(jìn)行了開源,并命名為cleanlab,我們可以pip install cleanlab使用,具體文檔說明在這里cleanlab文檔說明。
from cleanlab.pruning import get_noise_indices # 輸入 # s:噪聲標(biāo)簽 # psx: n x m 的預(yù)測概率概率,通過交叉驗(yàn)證獲得 ordered_label_errors = get_noise_indices(s=numpy_array_of_noisy_labels,psx=numpy_array_of_predicted_probabilities,sorted_index_method='normalized_margin', # Orders label errors)我們來看看cleanlab在MINIST數(shù)據(jù)集中找出的錯誤樣本吧,是不是感覺很牛~
如果你不只是想找到錯誤標(biāo)注的樣本,還想把這些標(biāo)注噪音clean掉之后重新繼續(xù)學(xué)習(xí),那3行codes也可以搞定,這時候連交叉驗(yàn)證都省了~:
from cleanlab.classification import LearningWithNoisyLabels from sklearn.linear_model import LogisticRegression# 其實(shí)可以封裝任意一個你自定義的模型. lnl = LearningWithNoisyLabels(clf=LogisticRegression()) lnl.fit(X=X_train_data, s=train_noisy_labels) # 對真實(shí)世界進(jìn)行驗(yàn)證. predicted_test_labels = lnl.predict(X_test)?
筆者注:上面雖然只給出了CV領(lǐng)域的例子,但置信學(xué)習(xí)也適用于NLP啊~此外,cleanlab可以封裝任意一個你自定義的模型,以下機(jī)器學(xué)習(xí)框架都適用:scikit-learn, PyTorch, TensorFlow, FastText。
03
置信學(xué)習(xí)的3個步驟
置信學(xué)習(xí)開源工具cleanlab操作起來比較容易,但置信學(xué)習(xí)背后也是有著充分的理論支持的。事實(shí)上,一個完整的置信學(xué)習(xí)框架,需要完成以下三個步驟 ( 如圖1所示 ):
-
Count:估計噪聲標(biāo)簽和真實(shí)標(biāo)簽的聯(lián)合分布;
-
Clean:找出并過濾掉錯誤樣本;
-
Re-Training:過濾錯誤樣本后,重新調(diào)整樣本類別權(quán)重,重新訓(xùn)練;
圖1 置信學(xué)習(xí)框架
下面對上述3個步驟進(jìn)行詳細(xì)闡述:
1. Count:估計噪聲標(biāo)簽和真實(shí)標(biāo)簽的聯(lián)合分布
我們定義噪聲標(biāo)簽為??,即經(jīng)過初始標(biāo)注(也許是人工標(biāo)注)、但可能存在錯誤的樣本;定義真實(shí)標(biāo)簽為?,但事實(shí)上我們并不會獲得真實(shí)標(biāo)簽,通常可通過交叉驗(yàn)證對真實(shí)標(biāo)簽進(jìn)行估計。此外,定義樣本總數(shù)為??,類別總數(shù)為??。
為了估計聯(lián)合分布,共需要4步:
step 1?:?交叉驗(yàn)證
-
首先需要通過對數(shù)據(jù)集集進(jìn)行交叉驗(yàn)證計算第??樣本在第??個類別下的概率??;
-
然后計算每個人工標(biāo)定類別??下的平均概率??作為置信度閾值;
-
最后對于樣本??,其真實(shí)標(biāo)簽??為??個類別中的最大概率??,并且??;
step 2:?
計算計數(shù)矩陣??(類似于混淆矩陣),如圖1中的?意味著,人工標(biāo)記為dog但實(shí)際為fox的樣本為40個。具體的操作流程如圖2所示:
圖2 計數(shù)矩陣C計算流程
step 3?: 標(biāo)定計數(shù)矩陣
目的就是為了讓計數(shù)總和與人工標(biāo)記的樣本總數(shù)相同。計算公式如下面所示,其中??為人工標(biāo)記標(biāo)簽??的樣本總個數(shù):
?①
step 4?:?
估計噪聲標(biāo)簽?和真實(shí)標(biāo)簽的聯(lián)合分布,可通過下式求得:
?②
看到這里,也許你會問為什么要估計這個聯(lián)合分布呢?其實(shí)這主要是為了下一步方便我們?nèi)lean噪聲數(shù)據(jù)。此外,這個聯(lián)合分布其實(shí)能充分反映真實(shí)世界中噪聲 ( 錯誤 ) 標(biāo)簽和真實(shí)標(biāo)簽的分布,隨著數(shù)據(jù)規(guī)模的擴(kuò)大,這種估計方法與真實(shí)分布越接近 ( 原論文中有著嚴(yán)謹(jǐn)?shù)淖C明,由于公式推導(dǎo)繁雜這里不再贅述,有興趣的同學(xué)可以詳細(xì)閱讀原文~,后文的圖7也有相關(guān)實(shí)驗(yàn)進(jìn)行證明 )。
看到這里,也許你還感覺公式好麻煩,那下面我們通過一個具體的例子來展示上述計算過程:
step 1 :?通過交叉驗(yàn)證獲取第??樣本在第??個類別下的概率??;為說明問題,這里假設(shè)共10個樣本、2個類別,每個類別有5個樣本。經(jīng)過計算每個人工標(biāo)簽類別??下的平均概率??分別為:??.
圖3 P[i][j]和t[j]計算
step2:根據(jù)圖2的計算流程,我們得到計數(shù)矩陣??為:
圖4 計數(shù)矩陣C計算
step3:標(biāo)定后的計數(shù)矩陣?為 ( 計數(shù)總和與人工標(biāo)記的樣本總數(shù)相同 ),將原來的樣本總數(shù)進(jìn)行加權(quán)即可,以??為例,根據(jù)公式①,其計算為??):
step4:聯(lián)合分布?為:( 根據(jù)公式②直接進(jìn)行概率歸一化即可 )
圖5 聯(lián)合分布Q計算
2. Clean:找出并過濾掉錯誤樣本
在得到噪聲標(biāo)簽和真實(shí)標(biāo)簽的聯(lián)合分布??,論文共提出了5種方法過濾錯誤樣本。
Method 1:,選取??的樣本進(jìn)行過濾,即選取??最大概率對應(yīng)的下標(biāo)??與人工標(biāo)簽不一致的樣本。
Method 2:,選取構(gòu)造計數(shù)矩陣??過程中、進(jìn)入非對角單元的樣本進(jìn)行過濾。
Method 3:Prune by Class ( PBC ),即對于人工標(biāo)記的每一個類別??,選取??個樣本過濾,并按照最低概率??排序。
Method 4:Prune by Noise Rate ( PBNR ),對于計數(shù)矩陣?的非對角單元,選取??個樣本進(jìn)行過濾,并按照最大間隔??排序。
Method 5:C+NR,同時采用Method 3和Method 4。
我們?nèi)匀灰詧D3給出的示例進(jìn)行說明:
Method 1:過濾掉i=2,3,4,8,9共5個樣本;
Method 2:進(jìn)入到計數(shù)矩陣非對角單元的樣本分別為i=3,4,9,將這3個樣本過濾;
Method 3:對于類別0,選取??個樣本過濾,按照最低概率排序,選取i=2,3,4;對于類別1,選取??個樣本過濾,按照最低概率排序選取i=9;綜上,共過濾i=2,3,4,9共4個樣本;
Method 4:對于非對角單元??選取i=2,3,4過濾,對??選取i=9過濾。
上述這些過濾樣本的方法在cleanlab也有提供,我們只要提供2個輸入、1行code即可clean錯誤樣本:
import cleanlab # 輸入 # s:噪聲標(biāo)簽 # psx: n x m 的預(yù)測概率概率,通過交叉驗(yàn)證獲得 # Method 3:Prune by Class (PBC) baseline_cl_pbc = cleanlab.pruning.get_noise_indices(s, psx, prune_method='prune_by_class',n_jobs=1) # Method 4:Prune by Noise Rate (PBNR) baseline_cl_pbnr = cleanlab.pruning.get_noise_indices(s, psx, prune_method='prune_by_noise_rate',n_jobs=1) # Method 5:C+NR baseline_cl_both = cleanlab.pruning.get_noise_indices(s, psx, prune_method='both',n_jobs=1)3. Re-Training:過濾錯誤樣本后,重新訓(xùn)練
在過濾掉錯誤樣本后,根據(jù)聯(lián)合分布??將每個類別i下的損失權(quán)重修正為:??,其中??.然后采取Co-Teaching[2]框架進(jìn)行。
圖6 Co-teaching
如圖6所示,Co-teaching的基本假設(shè)是認(rèn)為noisy label的loss要比clean label的要大,于是它并行地訓(xùn)練了兩個神經(jīng)網(wǎng)絡(luò)A和B,在每一個Mini-batch訓(xùn)練的過程中,每一個神經(jīng)網(wǎng)絡(luò)把它認(rèn)為loss比較小的樣本,送給它其另外一個網(wǎng)絡(luò),這樣不斷進(jìn)行迭代訓(xùn)練。
04
實(shí)驗(yàn)結(jié)果
上面我們介紹完成置信學(xué)習(xí)的3個步驟,本小節(jié)我們來看看這種置信學(xué)習(xí)框架在實(shí)踐中效果如何?在正式介紹之前,我們首先對稀疏率進(jìn)行定義:稀疏率為聯(lián)合分布矩陣、非對角單元中0所占的比率,這意味著真實(shí)世界中,總有一些樣本不會被輕易錯標(biāo)為某些類別,如老虎圖片不會被輕易錯標(biāo)為汽車。
圖7 真實(shí)聯(lián)合分布和估計聯(lián)合分布
圖7給出了CIFAR-10中,噪聲率為40%和稀疏率為60%情況下,真實(shí)聯(lián)合分布和估計聯(lián)合分布之間的比較,可以看出二者之間很接近,可見論文提出的置信學(xué)習(xí)框架用來估計聯(lián)合分布的有效性。
圖8 不同置信學(xué)習(xí)方法的比較
上圖給出了CIFAR-10中不同噪聲情況和稀疏性情況下,置信學(xué)習(xí)與其他SOTA方法的比較。例如在40%的噪聲率下,置信學(xué)習(xí)比之前SOTA方法Mentornet的準(zhǔn)確率平均提高34%。
圖9 置信學(xué)習(xí)發(fā)現(xiàn)的 ImageNet標(biāo)簽問題
論文還將提出置信學(xué)習(xí)框架應(yīng)用于真實(shí)世界的ImageNet數(shù)據(jù)集,利用CL:PBNR找出的TOP32標(biāo)簽問題如圖9所示,置信學(xué)習(xí)除了可以找出標(biāo)注錯誤的樣本 ( 紅色部分 ),也可以發(fā)現(xiàn)多標(biāo)簽問題 ( 藍(lán)色部分,圖像可以有多個標(biāo)簽 ),以及本體論問題:綠色部分,包括"是" ( 比如:將浴缸標(biāo)記為桶 ) 或"有" ( 比如:示波器標(biāo)記為CRT屏幕 ) 兩種關(guān)系。
圖10 不同置信學(xué)習(xí)方法和隨機(jī)去除的對比
圖10給出了分別去除20%,40%…,100%估計錯誤標(biāo)注的樣本后訓(xùn)練的準(zhǔn)確性,最多移除200K個樣本。可以看出,當(dāng)移除小于100K個訓(xùn)練樣本時,置信學(xué)習(xí)框架使得準(zhǔn)確率明顯提升,并優(yōu)于隨機(jī)去除。
05
總結(jié)
本文介紹了一種用來刻畫noisy label、找出錯誤標(biāo)注樣本的方法——置信學(xué)習(xí),是弱監(jiān)督學(xué)習(xí)和帶噪學(xué)習(xí)的一個分支。
置信學(xué)習(xí)直接估計噪聲標(biāo)簽和真實(shí)標(biāo)簽的聯(lián)合分布,而不是修復(fù)噪聲標(biāo)簽或者修改損失權(quán)重。
置信學(xué)習(xí)開源包c(diǎn)leanlab可以很快速的幫你找出那些錯誤樣本!可在分鐘級別之內(nèi)找出錯誤標(biāo)注的樣本。
總結(jié)
以上是生活随笔為你收集整理的置信学习:让样本中的“脏数据“原形毕露的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: MySQL优化CPU消耗
- 下一篇: 机器学习教程汇总