【CV】10分钟理解Focal loss数学原理与Pytorch代码
原文鏈接:https://amaarora.github.io/2020/06/29/FocalLoss.html
原文作者:Aman Arora
Focal loss 是一個(gè)在目標(biāo)檢測(cè)領(lǐng)域常用的損失函數(shù)。最近看到一篇博客,趁這個(gè)機(jī)會(huì),學(xué)習(xí)和翻譯一下,與大家一起交流和分享。
在這篇博客中,我們將會(huì)理解什么是Focal loss,并且什么時(shí)候應(yīng)該使用它。同時(shí)我們會(huì)深入理解下其背后的數(shù)學(xué)原理與pytorch 實(shí)現(xiàn).
什么是Focal loss,它是用來(lái)干嘛的?
為什么Focal loss有效,其中的原理是什么?
Alpha and Gamma?
怎么在代碼中實(shí)現(xiàn)它?
Credits
什么是Focal loss,它是用來(lái)干嘛的?
在了解什么是Focal Loss以及有關(guān)它的所有詳細(xì)信息之前,我們首先快速直觀(guān)地了解Focal Loss的實(shí)際作用。Focal loss最早是 He et al 在論文 Focal Loss for Dense Object Detection 中實(shí)現(xiàn)的。
在這篇文章發(fā)表之前,對(duì)象檢測(cè)實(shí)際上一直被認(rèn)為是一個(gè)很難解決的問(wèn)題,尤其是很難檢測(cè)圖像中的小尺寸對(duì)象。請(qǐng)參見(jiàn)下面的示例,與其他圖片相比,摩托車(chē)的尺寸相對(duì)較小, 所以該模型無(wú)法很好地預(yù)測(cè)摩托車(chē)的存在。
fig-1??bce?
在上圖中,模型無(wú)法預(yù)測(cè)摩托車(chē)的原因是因?yàn)樵撃P褪鞘褂昧薆inary Cross Entropy loss,這種訓(xùn)練目標(biāo)要求模型 對(duì)自己的預(yù)測(cè)真的很有信心。而Focal Loss所做的是,它使模型可以更"放松"地預(yù)測(cè)事物,而無(wú)需80-100%確信此對(duì)象是“某物”。簡(jiǎn)而言之,它給模型提供了更多的自由,可以在進(jìn)行預(yù)測(cè)時(shí)承擔(dān)一些風(fēng)險(xiǎn)。這在處理高度不平衡的數(shù)據(jù)集時(shí)尤其重要,因?yàn)樵谀承┣闆r下(例如癌癥檢測(cè)),即使預(yù)測(cè)結(jié)果為假陽(yáng)性也可接受,確實(shí)需要模型承擔(dān)風(fēng)險(xiǎn)并盡量進(jìn)行預(yù)測(cè)。
因此,Focal loss在樣本不平衡的情況下特別有用。特別是在“對(duì)象檢測(cè)”的情況下,大多數(shù)像素通常都是背景,圖像中只有很少數(shù)的像素具有我們感興趣的對(duì)象。
這是經(jīng)過(guò)Focal loss訓(xùn)練后同一模型對(duì)同樣圖片的預(yù)測(cè)。
fig-2??focal loss prediction
分析這兩者并觀(guān)察其中的差異,可能是個(gè)很好的主意。這將有助于我們對(duì)于Focal loss進(jìn)行直觀(guān)的了解。
那么為什么Focal loss有效,其中的原理是什么?
既然我們已經(jīng)看到了“Focal loss”可以做什么的一個(gè)例子,接下來(lái)讓我們嘗試去理解為什么它可以起作用。下面是了解Focal loss的最重要的一張圖:
fig-3 FL vs CE
在上圖中,“藍(lán)”線(xiàn)代表交叉熵?fù)p失。X軸即“預(yù)測(cè)為真實(shí)標(biāo)簽的概率”(為簡(jiǎn)單起見(jiàn),將其稱(chēng)為pt)。舉例來(lái)說(shuō),假設(shè)模型預(yù)測(cè)某物是自行車(chē)的概率為0.6,而它確實(shí)是自行車(chē), 在這種情況下的pt為0.6。而如果同樣的情況下對(duì)象不是自行車(chē)。則pt為0.4,因?yàn)榇颂幍恼鎸?shí)標(biāo)簽是0,而對(duì)象不是自行車(chē)的概率為0.4(1-0.6)。
Y軸是給定pt后Focal loss和CE的loss的值。
從圖像中可以看出,當(dāng)模型預(yù)測(cè)為真實(shí)標(biāo)簽的概率為0.6左右時(shí),交叉熵?fù)p失仍在0.5左右。因此,為了在訓(xùn)練過(guò)程中減少損失,我們的模型將必須以更高的概率來(lái)預(yù)測(cè)到真實(shí)標(biāo)簽。換句話(huà)說(shuō),交叉熵?fù)p失要求模型對(duì)自己的預(yù)測(cè)非常有信心。但這也同樣會(huì)給模型表現(xiàn)帶來(lái)負(fù)面影響。
深度學(xué)習(xí)模型會(huì)變得過(guò)度自信, 因此模型的泛化能力會(huì)下降.
這個(gè)模型過(guò)度自信的問(wèn)題同樣在另一篇出色的論文 Beyond temperature scaling: Obtaining well-calibrated multiclass probabilities with Dirichlet calibration 被強(qiáng)調(diào)過(guò)。
另外,作為重新思考計(jì)算機(jī)視覺(jué)的初始架構(gòu)的一部分而引入的標(biāo)簽平滑是解決該問(wèn)題的另一種方法。
Focal loss與上述解決方案不同。從比較Focal loss與CrossEntropy的圖表可以看出,當(dāng)使用γ> 1的Focal Loss可以減少“分類(lèi)得好的樣本”或者說(shuō)“模型預(yù)測(cè)正確概率大”的樣本的訓(xùn)練損失,而對(duì)于“難以分類(lèi)的示例”,比如預(yù)測(cè)概率小于0.5的,則不會(huì)減小太多損失。因此,在數(shù)據(jù)類(lèi)別不平衡的情況下,會(huì)讓模型的注意力放在稀少的類(lèi)別上,因?yàn)檫@些類(lèi)別的樣本見(jiàn)過(guò)的少,比較難分。
Focal loss的數(shù)學(xué)定義如下:
Alpha and Gamma?
那么在Focal loss 中的alpha和gamma是什么呢?我們會(huì)將alpha記為α,gamma記為γ。
我們可以這樣來(lái)理解fig3
γ?控制曲線(xiàn)的形狀.?γ的值越大, 好分類(lèi)樣本的loss就越小, 我們就可以把模型的注意力投向那些難分類(lèi)的樣本. 一個(gè)大的?γ?讓獲得小loss的樣本范圍擴(kuò)大了.
同時(shí),當(dāng)γ=0時(shí),這個(gè)表達(dá)式就退化成了Cross Entropy Loss,眾所周知地
定義“ pt”如下,按照其真實(shí)意義:
將上述兩個(gè)式子合并,Cross Entropy Loss其實(shí)就變成了下式。
現(xiàn)在我們知道了γ的作用,那么α是干什么的呢?
除了Focal loss以外,另一種處理類(lèi)別不均衡的方法是引入權(quán)重。給稀有類(lèi)別以高權(quán)重,給統(tǒng)治地位的類(lèi)或普通類(lèi)以小權(quán)重。這些權(quán)重我們也可以用α表示。
alpha-CE
加上了這些權(quán)重確實(shí)幫助處理了類(lèi)別的 不均衡,focal loss的論文報(bào)道:
類(lèi)間不均衡較大會(huì)導(dǎo)致,交叉熵?fù)p失在訓(xùn)練的時(shí)候收到影響。易分類(lèi)的樣本的分類(lèi)錯(cuò)誤的損失占了整體損失的絕大部分,并主導(dǎo)梯度。盡管α平衡了正面/負(fù)面例子的重要性,但它并未區(qū)分簡(jiǎn)單/困難例子。
作者想要解釋的是:
盡管我們加上了α, 它也確實(shí)對(duì)不同的類(lèi)別加上了不同的權(quán)重, 從而平衡了正負(fù)樣本的重要性 ,但在大多數(shù)例子中,只做這個(gè)是不夠的. 我們同樣要做的是減少容易分類(lèi)的樣本分類(lèi)錯(cuò)誤的損失。因?yàn)椴蝗坏脑?huà),這些容易分類(lèi)的樣本就主導(dǎo)了我們的訓(xùn)練.
那么Focal loss 怎么處理的呢,它相對(duì)交叉熵加上了一個(gè)乘性的因子(1 ? pt)**γ,從而像我們上面所講的,降低了易分類(lèi)樣本區(qū)間內(nèi)產(chǎn)生的loss。
再看下Focal loss的表達(dá),是不是清晰了許多。
怎么在代碼中實(shí)現(xiàn)呢?
這是Focal loss在Pytorch中的實(shí)現(xiàn)。
class WeightedFocalLoss(nn.Module):"Non weighted version of Focal Loss"def __init__(self, alpha=.25, gamma=2):super(WeightedFocalLoss, self).__init__()self.alpha = torch.tensor([alpha, 1-alpha]).cuda()self.gamma = gammadef forward(self, inputs, targets):BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')targets = targets.type(torch.long)at = self.alpha.gather(0, targets.data.view(-1))pt = torch.exp(-BCE_loss)F_loss = at*(1-pt)**self.gamma * BCE_lossreturn F_loss.mean()如果你理解了alpha和gamma的意思,那么這個(gè)實(shí)現(xiàn)應(yīng)該都能理解。同時(shí),像文章中提到的一樣,這里是對(duì)BCE進(jìn)行因子的相乘。
Credits
貼上作者的 twitter ,當(dāng)然如果大家有什么問(wèn)題討論,也可以在公眾號(hào)留言。
fig-1?and?fig-2?are from the?Fastai 2018 course?Lecture-09!
未完待續(xù)
今天給大家分享到這里,感謝大家的閱讀和支持,我們會(huì)繼續(xù)給大家分享我們的所思所想所學(xué),希望大家都有收獲!
往期精彩回顧適合初學(xué)者入門(mén)人工智能的路線(xiàn)及資料下載機(jī)器學(xué)習(xí)及深度學(xué)習(xí)筆記等資料打印機(jī)器學(xué)習(xí)在線(xiàn)手冊(cè)深度學(xué)習(xí)筆記專(zhuān)輯《統(tǒng)計(jì)學(xué)習(xí)方法》的代碼復(fù)現(xiàn)專(zhuān)輯 AI基礎(chǔ)下載機(jī)器學(xué)習(xí)的數(shù)學(xué)基礎(chǔ)專(zhuān)輯獲取一折本站知識(shí)星球優(yōu)惠券,復(fù)制鏈接直接打開(kāi):https://t.zsxq.com/yFQV7am本站qq群1003271085。加入微信群請(qǐng)掃碼進(jìn)群:總結(jié)
以上是生活随笔為你收集整理的【CV】10分钟理解Focal loss数学原理与Pytorch代码的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 【Python】用 Python 来实现
- 下一篇: 【机器学习基础】八种应对样本不均衡的策略