GAN半监督学习
概述
GAN的發(fā)明者Ian Goodfellow2016年在Open AI任職期間發(fā)表了這篇論文,其中提到了GAN用于半監(jiān)督學習(semi supervised)的方法。稱為SSGAN。?
作者給出了Theano+Lasagne實現(xiàn)。本文結合源碼對這種方法的推導和實現(xiàn)進行講解。1
半監(jiān)督學習
考慮一個分類問題。?
如果訓練集中大部分樣本沒有標記類別,只有少部分樣本有標記。則需要用半監(jiān)督學習(semi-supervised)方法來訓練一個分類器。
wiki上的這張圖很好地說明了無標記樣本在半監(jiān)督學習中發(fā)揮作用:?
如果只考慮有標記樣本(黑白點),純粹使用監(jiān)督學習。則得到垂直的分類面。?
考慮了無標記樣本(灰色點)之后,我們對樣本的整體分布有了進一步認識,能夠得到新的、更準確的分類面。
核心理念
在半監(jiān)督學習中運用GAN的邏輯如下。
- 無標記樣本沒有類別信息,無法訓練分類器;
- 引入GAN后,其中生成器(Generator)可以從隨機信號生成偽樣本;
- 相比之下,原有的無標記樣本擁有了人造類別:真。可以和偽樣本一起訓練分類器。?
舉個通俗的例子:就算沒人教認字,多練練分辨“是不是字”也對認字有好處。有粗糙的反饋,也比沒有反饋強。
原理
框架
GAN中的兩個核心模塊是生成器(Generator)和鑒別器(Discriminator)。這里用分類器(Classifier)代替了鑒別器。?
訓練集中包含有標簽樣本xlxl和無標簽樣本xuxu。?
生成器從隨機噪聲生成偽樣本IfIf。?
分類器接受樣本II,對于KK類分類問題,輸出K+1K+1維估計ll,再經(jīng)過softmax函數(shù)得到概率pp:其前KK維對應原有KK個類,最后一維對應“偽樣本”類。?
pp的最大值位置對應為估計標簽yy。
三種誤差
整個系統(tǒng)涉及三種誤差。
對于訓練集中的有標簽樣本,考察估計的標簽是否正確。即,計算分類為相應的概率:?
對于訓練集中的無標簽樣本,考察是否估計為“真”。即,計算不估計為K+1K+1類的概率:?
對于生成器產(chǎn)生的偽樣本,考察是否估計為“偽”。即,計算估計為K+1K+1類的概率:?
推導
考慮softmax函數(shù)的一個特性:?
即,如果輸入各維減去同一個數(shù),softmax結果不變。?
于是,可以令 l→l?lK+1l→l?lK+1 ,有 lK+1=0lK+1=0 , p=softmax(l)p=softmax(l) 保持不變。
期望號略去不寫,利用explK+1=1,exp?lK+1=1,后兩種代價變?yōu)?#xff1a;?
上述推導可以讓我們省去lK+1lK+1,讓分類器仍然輸出K維的估計ll。
對于第一個代價,由于分類器輸入必定來自前K類,所以可以直接使用ll的前K維:?
引入兩個函數(shù),使得書寫更為簡潔:
LSE(x)=ln[∑j=1expxj]LSE(x)=ln?[∑j=1exp?xj]softplus(x)=ln(1+expx)softplus(x)=ln?(1+exp?x)三個誤差:?
優(yōu)化目標
對于分類器來說,希望上述誤差盡量小。引入權重ww,得到分類器優(yōu)化目標:?
對于生成器來說,希望其輸出的偽樣本能夠騙過分類器。生成器優(yōu)化目標與分類器的第三項相反:?
實驗
本文的實驗包含三個圖像分類問題。分類器接受圖像xx,輸出KK類分類結果ll。生成器從均勻分布的噪聲zz生成一張圖像xx。
MNIST
10分類問題,圖像為28*28灰度。
生成器是一個3層線性網(wǎng)絡:?
分類器是一個6層線性網(wǎng)絡:?
訓練樣本60K個,測試樣本10K個。?
選擇不同數(shù)量的訓練樣本給予標記,考察測試樣本中錯誤個數(shù)。使用不同隨機數(shù)種子重復10次:
| 占比 | 0.033% | 0.083% | 0.17% | 0.33% |
| 錯誤個數(shù) | 1677±452 | 221±136 | 93±6.5 | 90±4.2 |
Cifar10
10分類問題,圖像為32*32彩色。
生成器是一個4層反卷積網(wǎng)絡:?
分類器是一個9層卷積網(wǎng)絡:?
訓練樣本50K個,測試樣本10K個。?
選擇不同數(shù)量的訓練樣本給予標記,考察測試樣本中錯誤個數(shù)。使用不同的測試/訓練分割重復10次:
| 占比 | 2% | 4% | 8% | 16% |
| 錯誤個數(shù) | 21.83±2.01 | 19.61±2.09 | 18.63±2.32 | 17.72±1.82 |
SVHN
10分類問題,圖像為32*32彩色。
生成器(上)以及分類器(下)和CIFAR10的結構非常類似。?
訓練樣本73K,測試樣本26K。?
選擇不同數(shù)量的訓練樣本給予標記,考察測試樣本中錯誤個數(shù)。使用不同的測試/訓練分割重復10次:
| 占比 | 0.68% | 1.4% | 2.7% |
| 錯誤個數(shù) | 18.84±4.8 | 8.11±1.3 | 6.16±0.58 |
總結