GAN学习
開始學習GAN生成對抗網(wǎng)絡相關知識,將要點和心得總結于此。
文章目錄
- 起源
- 主要思想
- 特點
- 訓練技巧
- 應用場景
- 其他
- GAN及其改進
- GAN
- DCGAN
- WGAN和WGAN-gp
- LSGAN
- cGAN
- pix2pix
- CycleGAN
起源
GAN,全名 Generative Adversarial Networks,即生成式對抗網(wǎng)絡,是2014年Lan Goodfellow的論文《Generative Adversarial Nets》中提出的一種新的方法,是一種無監(jiān)督學習模型,通過學習樣本分布讓算法生成類似分布的圖片。
主要思想
GAN的主要靈感來源于博弈論中零和博弈的思想。通過生成網(wǎng)絡G(Generator)和判別網(wǎng)絡D(Discriminator)不斷博弈,進而使G學習到數(shù)據(jù)的分布,根據(jù)一定的映射規(guī)則從一段隨機數(shù)中生成逼真的圖像。
G是一個生成網(wǎng)絡,輸入為一個隨機的噪聲,輸出為的生成圖像。
D是一個判別網(wǎng)絡,輸入為一張圖片,輸出為真實圖片的概率,范圍為0-1。
訓練過程中,G的目標就是盡量生成真實的圖片去欺騙D。而D的目標就是盡量辨別出G生成的假圖像。這樣,G和D構成了一個動態(tài)的“博弈過程”,最終的平衡點即納什均衡點。
G的梯度更新信息來自判別器D,而不是來自數(shù)據(jù)樣本。
特點
GAN 的優(yōu)點:
GAN的缺點:
模式崩潰(model collapse):生成的數(shù)據(jù)多樣性不足。原GAN論文中提出的loss函數(shù)經(jīng)過變換后為KL散度項,KL散度不具有對稱性,即KL(A|B)≠KL(B|A)。
故在優(yōu)化過程中l(wèi)oss對于兩種錯誤的懲罰不同,第一種錯誤表示樣本中包含的數(shù)據(jù)沒有被生成,即缺乏多樣性,懲罰微小;第二種錯誤表示生成的數(shù)據(jù)在樣本中不存在 ,即缺乏準確性,懲罰巨大。由于不平衡的懲罰導致生成器寧可多生成一些重復但是正確的樣本,也不愿意去生成多樣性的樣本,因為那樣一不小心就會產(chǎn)生第二種錯誤。這種現(xiàn)象就是大家常說的collapse mode。
訓練技巧
應用場景
GAN應用匯總
常見GAN變體及實現(xiàn)
其他
為什么GAN中的優(yōu)化器不常用SGD
為什么GAN不適合處理文本數(shù)據(jù)
GAN及其改進
GAN
如上圖所示,生成對抗網(wǎng)絡會訓練并更新判別分布(即 D,藍色的虛線),更新判別器后就能將數(shù)據(jù)真實分布(黑點組成的線)從生成分布 P_g(G)(綠色實線)中判別出來。下方的水平線代表采樣域 Z,其中等距線表示 Z 中的樣本為均勻分布,上方的水平線代表真實數(shù)據(jù) X 中的一部分。向上的箭頭表示映射 x=G(z) 如何對噪聲樣本(均勻采樣)施加一個不均勻的分布 P_g。(a)考慮在收斂點附近的對抗訓練:P_g 和 P_data 已經(jīng)十分相似,D 是一個局部準確的分類器。(b)在算法內(nèi)部循環(huán)中訓練 D 以從數(shù)據(jù)中判別出真實樣本,該循環(huán)最終會收斂到 D(x)=P_data(x)/(P_data(x)+P_g(x))。(c)隨后固定判別器并訓練生成器,在更新 G 之后,D 的梯度會引導 G(z)流向更可能被 D 分類為真實數(shù)據(jù)的方向。(d)經(jīng)過若干次訓練后,如果 G 和 D 有足夠的復雜度,那么它們就會到達一個均衡點。這個時候 P_g=P_data,即生成器的概率密度函數(shù)等于真實數(shù)據(jù)的概率密度函數(shù),也即生成的數(shù)據(jù)和真實數(shù)據(jù)是一樣的。在均衡點上 D 和 G 都不能得到進一步提升,并且判別器無法判斷數(shù)據(jù)到底是來自真實樣本還是偽造的數(shù)據(jù),即 D(x)= 1/2。
具體算法實現(xiàn)
參考資料:
機器之心GitHub項目:GAN完整理論推導與實現(xiàn)
KL散度、JS散度以及交叉熵對比
Generative Adversarial Nets(譯)
DCGAN
將GAN與CNN相結合,將原論文中的MLP網(wǎng)絡更換為CNN網(wǎng)絡,改善了對圖片的生成與判別效果。
主要貢獻是:
為GAN的訓練提供了一個很好的網(wǎng)絡拓撲結構。
表明生成的特征具有向量的計算特性。
使用的CNN結構如下
判別器幾乎是和生成器對稱的。整個網(wǎng)絡沒有pooling層和上采樣層,實際上是使用了帶步長(fractional-strided)的卷積代替了上采樣,以增加訓練的穩(wěn)定性。
DCGAN能改進GAN訓練穩(wěn)定的原因主要有:
- 使用步長卷積代替上采樣層,卷積在提取圖像特征上具有很好的作用,并且使用卷積代替全連接層。
- 生成器G和判別器D中幾乎每一層都使用batchnorm層,將特征層的輸出歸一化到一起,加速了訓練,提升了訓練的穩(wěn)定性。(生成器的最后一層和判別器的第一層不加batchnorm)
- 在判別器中使用leaky-ReLU激活函數(shù),而不是ReLU,防止梯度稀疏,生成器中仍然采用ReLU,但是輸出層采用tanh
- 使用adam優(yōu)化器訓練,學習率推薦為0.0002
參考資料:
DCGAN、WGAN、WGAN-GP、LSGAN、BEGAN原理總結及對比
DCGAN在TF上實現(xiàn)
WGAN和WGAN-gp
Wasserstein距離
WGAN
GAN定義的損失函數(shù)具有一定的缺陷,具體表現(xiàn)為:訓練不穩(wěn)定,存在collapse mode情況,并且在訓練目標不明確的問題。針對這些缺陷,WGAN提出了用Wasserstein距離替代JS散度,用于計算真實樣本與生成樣本之間的差異。將判別器的作用從判斷樣本是否為真轉化為計算真實樣本與生成樣本之間Wasserstein距離,通過不斷減少這個距離可以優(yōu)化生成器。Wasserstein距離也是一個明確的指示標志,表面了當前模型訓練情況,Wasserstein距離越小說明生成樣本與真實樣本越接近。此外,由于Wasserstein距離具有對稱性,還(基本上)解決了collapse mode情況。
具體算法如下
WGAN-gp
WGAN的梯度裁剪的方法具有一定的弊端,會產(chǎn)生下圖左側情況
容易引起梯度消失或者梯度爆炸情況,因此使用懲罰系數(shù)代替梯度裁剪。具體方法是對損失函數(shù)加入梯度懲罰項,當梯度大于1時進行懲罰,保證Lipschitz連續(xù)性限制。該懲罰項的梯度位置在真實樣本與生成樣本中的連線中隨機采樣某一點,然后計算D(x)并求梯度,最后計算與1的距離。具體形式為:
令人拍案叫絕的Wasserstein GAN
W-GAN系 (Wasserstein GAN、 Improved WGAN)
WGAN在TF上實現(xiàn)
LSGAN
將GAN的損失函數(shù)更換為最小二乘損失函數(shù),其目的與WGAN類似,即JS散度具有不對稱性和范圍(0-1),因此不能拉近真實分布和生成分布之間的距離,使用最小二乘可以將圖像的分布盡可能的接近決策邊界。LSGAN損失函數(shù)定義如下:
minDJ(D)=minD12Ex~Pr[D(x)?a]2+12Ez~Pz[D(G(x))?b]2minGJ(G)=minG12Ez~Pz[D(G(x))?c]2\underset{D}{min}J(D)=\underset{D}{min}\frac{1}{2}E_{x\sim P_{r}}[D(x)-a]^{2}+\frac{1}{2}E_{z\sim P_{z}}[D(G(x))-b]^{2}\\ \underset{G}{min}J(G)=\underset{G}{min}\frac{1}{2}E_{z\sim P_{z}}[D(G(x))-c]^{2}Dmin?J(D)=Dmin?21?Ex~Pr??[D(x)?a]2+21?Ez~Pz??[D(G(x))?b]2Gmin?J(G)=Gmin?21?Ez~Pz??[D(G(x))?c]2作者設置a=c=1,b=0。
參考資料:GAN——LSGANs(最小二乘GAN)
cGAN
GAN的訓練為無監(jiān)督訓練,生成的圖片具有隨機性。為了得到可控的結果,在生成器G與判別器D中均加入給定條件y。這里標簽與生成圖片進行堆疊送入判別器。
目標函數(shù)如下:
參考資料:CGAN論文筆記
詳解GAN代碼之搭建并詳解CGAN代碼
pix2pix
在CGAN基礎上的改進。為了使生成的圖片更接近訓練圖片,加入和L1損失,其損失函數(shù)定義如下:
網(wǎng)絡結構使用了U-Net結構,能夠減少Encoder-Decoder過程中對于原始信息的丟失,其原理如下:
將判別器改變?yōu)榫植颗袆e器(Patch-D),即將圖像分為固定大小的部分送入判別器。
優(yōu)點:
參考資料:Pix2Pix-基于GAN的圖像翻譯
Image-to-Image Translation in Tensorflow
CycleGAN
總結
- 上一篇: Hive的元数据表结构详解(转自lxw1
- 下一篇: 谭浩强-习题4.8