多类SVM的损失函数
原文:Multi-class SVM Loss?
作者:?Adrian Rosebrock?
翻譯:?KK4SBB?責編:何永燦
from:?http://geek.csdn.net/news/detail/101547
幾個星期之前,我們討論了線性分類和參數化學習的概念。這類學習方法使我們能夠輸入一組數據和類別標簽,然后從中學到一個從輸入值到預測值的映射關系,而我們只需要定義一組參數并優化這些參數。
我們本篇線性分類器教程主要關注評分函數的概念和它的用法。但是,為了真的“學會”輸入值和類別標簽的映射關系,我們需要討論下面兩個重要的概念:
- 損失函數
- 優化方法
在本周和下周的文章中,我們會討論兩類常見的損失函數,它們在機器學習、神經網絡和深度學習算法中都被應用:
- 多類SVM損失
- 交叉熵(用于Softmax分類器/多項式邏輯回歸)
接下來,我們就討論多類SVM損失。
多類SVM損失
用最簡單的方式來解釋,損失函數就是用來衡量一個預測器在對輸入數據進行分類預測時的質量好壞。
損失值越小,分類器的效果越好,越能反映輸入數據與輸出類別標簽的關系(雖然我們的模型有時候會過擬合——這是由于訓練數據被過度擬合,導致我們的模型失去了泛化能力)。
相反,損失值越大,我們需要花更多的精力來提升模型的準確率。就參數化學習而言,這涉及到調整參數,比如需要調節權重矩陣W或偏置向量B,以提高分類的精度。確切地說,我們如何去更新這些參數屬于優化問題,我們這一系列的教程的后續篇幅將會覆蓋這些話題。
多類SVM損失背后的數學問題
在閱讀完Python的線性分類教程之后,你會發現我們選用的分類器是線性支持向量機(linear SVM)。
上一篇教程著重介紹了評分函數f的概念,它把我們的特征向量映射為數值型的類別標簽。如其名稱所示,線性SVM采用簡單的線性映射:
現在我們有了評分/映射函數f,我們需要確定這個函數預測的質量(給定權重矩陣W和偏置向量b)是“好”還是“壞”。
為了完成這一目標,需要定義一個損失函數。接著,我們就來給損失函數下一個定義。
基于之前的線性分類器教程,我們知道當前有一個特征向量矩陣x —— 這些特征向量可以從顏色直方圖中獲取,也可以是HOG特征,或者甚至是原始像素值。
無論我們如何選擇量化圖像,我們都能從圖像數據集中抽取出一個特征矩陣x。然后,我們可以用xi獲取某張圖片的第i維特征,也就是x的第i個特征向量。
同樣的,我們也有一個向量y,存儲了每個x的類別標簽。這些y值是我們的參照標簽,正是我們希望評分函數能夠準確預測的標簽值。就像我們可以用xi得到某個特征向量,我們也可以用yi讀取第i個類別標簽。
為了簡化,我們將評分函數簡寫為s:
第i個數據的第j類預測得分值可以表示為:
按照上述定義,我們將它代入公式,得到了hinge損失函數:
注意:我先故意略過正則化參數項。在后續的文章中,當我們理解了損失函數,我會再來介紹正則化。
那么,上面那個方程究竟有什么用途?
我很高興你能提出這樣的問題。
簡單來說,hinge損失函數將預測不正確的類別()累加,然后將我們評分函數s在第j類(不正確類別)的輸出值與在第yi類的輸出值比較。
然后應用max函數,使得函數的輸出值不小于0 —— 這一點非常重要,因而輸出不會出現負值。
若Li=0,說明給定的數據xi被正確分類了(我在后續的章節中會舉一個例子)。
當把損失值推廣到整個訓練數據集,我們對所有的Li取平均數:
此外,常用的損失函數還有平方hinge損失:
平方項對損失值的懲罰力度更大。
至于選用何種損失函數,這需要視數據集而定。標準的hinge損失函數比較常見,但某些數據集可能使用平方項能取得更好的精度 —— 總之,這是一個需要你交叉驗證的超參數。
多類SVM損失示例
現在,再來討論hinge損失和平方hinge損失的數學原理,以下面的問題為例。
我們再一次選用Kaggle的狗vs.貓數據集,即判斷指定圖片里包含了貓還是狗。
這個數據集中只包含了兩種可能的類別標簽,因此屬于二分類問題,可以用標準的二項SVM損失函數求解。也就是說,我們仍然使用多類SVM損失,所以我們可以有一個成功實踐的例子。然后,我會擴展示例來處理三種類別的問題。
首先,看看下面的圖片,圖片是來自“狗vs.貓”數據集的兩個訓練樣本:
給定任意的權重矩陣W和偏置向量b,f(x,W)=Wx+b函數的輸出分數如上表所示。分數值越大,說明我們的評分函數對預測結果的置信度越高。
我們先來計算“狗”類的損失值Li。假設一個二分類問題,這就非常容易:
>>> max(0, 1.33 - 4.26 + 1) 0 >>>請注意“狗”的損失值為啥等于零 —— 意思是正確地預測了狗的類別。快速地回顧上述圖1所示的內容:“狗”的分值大于“貓”的分值。
同樣的,我們對第二張圖像采取相同的做法,這張圖片包含了一只貓:
>>> max(0, 3.76 - (-1.2) + 1) 5.96 >>>損失函數的輸出值大于零,意味著我們的預測結果不正確。
我們計算兩張圖片的損失值的均值作為整體損失值:
>>> (0 + 5.96) / 2 2.98 >>對于二分類問題,計算過程非常簡單,那對于三分類問題呢?過程會變得復雜嗎?
事實上,并沒有復雜 —— 下圖是一個三類問題的示例,我新加入了一個類別“馬”:
再次計算“狗”這一類的損失值:
>>> max(0, 1.49 - (-0.39) + 1) + max(0, 4.21 - (-0.39) + 1) 8.48 >>>請注意我們是如何將求和部分擴展到兩項計算的 —— 分別計算“狗”類的預測得分與“貓”類和”馬”類分值的差。
同樣的,計算”貓”這一類的損失值:
>>> max(0, -4.61 - 3.28 + 1) + max(0, 1.46 - 3.28 + 1) 0 >>>最后,計算”馬”這一類的損失值:
>>> max(0, 1.03 - (-2.27) + 1) + max(0, -2.37 - (-2.27) + 1) 5.199999999999999 >>>因此,整體損失值是:
>>> (8.48 + 0.0 + 5.2) / 3 4.56 >>>正如你所看到的,它們都適用同樣的原則 —— 只要記住在擴展類別數目的同時,求和的項數也要擴展。
測驗:根據上面三類的損失值判斷,哪一類是正確的預測值?
我需要動手實現多類SVM損失值計算嗎?
如果你愿意,也可以動手實現hinge和平方hinge損失值 —— 但這主要還是出于學習的目的。
你幾乎可以在所有的機器學習/深度學習庫里找到hinge損失和平方hinge損失的實現,比如scikit-learn, Keras, Caffe等等。
總結
今天我們討論了多類SVM損失的概念。給定一個評分函數(將輸入數據映射到輸出的類別標簽),我們的損失函數可以用來定量評判評分函數預測正確類別標簽質量的“好”與“壞”。
損失值越小,我們的預測越準確(但存在過擬合的風險,映射函數過于擬合了輸入數據)。
相反,損失值越大,我們的預測結果越不準確,因此需要繼續優化參數W和b —— 當我們更深入地理解損失函數之后,后續文章會介紹優化方法。
理解“損失”的概念以及它在機器學習和深度學習算法中的應用之后,我們仔細研究了兩類損失函數:
- hinge損失函數
- 平方hinge損失函數
通常,hinge損失更常見 —— 但仍然需要調優分類器的超參數來判斷哪種損失函數更適合你的數據集。
總結
以上是生活随笔為你收集整理的多类SVM的损失函数的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: OpenMP: OpenMP编程指南
- 下一篇: 缓存架构设计细节二三事