【Pytorch神经网络理论篇】 23 对抗神经网络:概述流程 + WGAN模型 + WGAN-gp模型 + 条件GAN + WGAN-div + W散度
1 對抗神經(jīng)簡介
1.1 對抗神經(jīng)網(wǎng)絡(luò)的基本組成
1.1.1 基本構(gòu)成
對抗神經(jīng)網(wǎng)絡(luò)(即生成式對抗網(wǎng)絡(luò),GAN)一般由兩個模型組成:
- 生成器模型(generator):用于合成與真實樣本相差無幾的模擬樣本。
- 判別器模型(discriminator):用于判斷某個樣本是來自真實世界還是模擬生成的。
1.1.2 不同模型的在GAN中的主要作用
生成器模型的目的是,讓判別器模型將合成樣本當(dāng)成直實樣本。
判別器模的目的是,將合成樣本與真實樣本區(qū)分開。
1.1.3 獨立任務(wù)
若將兩個模型放在一起同步訓(xùn)練,那么生成器模型生成的模擬樣本會更加真實,判別器模型對樣本的判斷會更加精準。
- 生成器模型可以當(dāng)成生成式模型,用來獨立處理生成式任務(wù);
- 判別器模型可以當(dāng)成分類器模型,用來獨立處理分類任務(wù)。
1.2 對抗神經(jīng)網(wǎng)絡(luò)的工作流程
1.2.1生成器模型
生成器模型的輸入是一個隨機編碼向量,輸出是一個復(fù)雜樣本(如圖片)。從訓(xùn)練數(shù)據(jù)中產(chǎn)生相同分布的樣本。對于輸入樣本x,類別標簽y,在生成器模型中估計其聯(lián)合概率分布,生成與輸入樣本x更為相似的樣本。
1.2.2 判別器模型
根據(jù)條件概率分布區(qū)分真假樣本。它的輸入是一個復(fù)雜樣本,輸出是一個概率。這個概率用來判定輸入樣本是真實樣本還是生成器輸出的模擬樣本。
1.2.3 工作流程簡介
生成器模型與判別器模型都采用監(jiān)督學(xué)習(xí)方式進行訓(xùn)練。二者的訓(xùn)練目標相反,存在對抗關(guān)系。將二者結(jié)合后,將形成如下圖所示的網(wǎng)絡(luò)結(jié)構(gòu)。
對抗神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)對抗神經(jīng)網(wǎng)絡(luò)的訓(xùn)練方法各種各樣,但其原理都是一樣的,即在迭代練的優(yōu)化過程中進行兩個網(wǎng)絡(luò)的優(yōu)化。有的方法會在一個優(yōu)化步驟中對兩個網(wǎng)絡(luò)進行優(yōu)化、有的會對兩個網(wǎng)絡(luò)采取不同的優(yōu)化步驟。
經(jīng)過大量的迭代訓(xùn)練會使生成器模型盡可能模擬出“以假亂真”的樣本,而判別模型會有更精確的鑒別真?zhèn)螖?shù)據(jù)的能力,從而使整個對抗神經(jīng)網(wǎng)絡(luò)最終達到所謂的納什均衡,即判別器模型對于生成器模型輸出數(shù)據(jù)的鑒別結(jié)果為50%直、50%假。
1.3 對抗神經(jīng)網(wǎng)絡(luò)的功能
監(jiān)督學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)都屬于判別器模型,自編碼神經(jīng)網(wǎng)絡(luò)中,編碼器部分就屬于一個生成器模型
1.3.1?生成器模型的特性
- 在應(yīng)用數(shù)學(xué)和工程方面,能夠有效地表征高維數(shù)據(jù)分布。
- 在強化學(xué)習(xí)方面,作為一種技術(shù)手段有效表征強化學(xué)習(xí)模型中的狀態(tài)。
- 在半覽督學(xué)習(xí)方面,能夠在數(shù)據(jù)缺失的情況下訓(xùn)練模型、并給出相應(yīng)的輸出。
1.3.2 舉例
在視頻中,通過場景預(yù)測下一幀的場景,而判別器模型的輸出是維度很低的判別結(jié)果和期望輸出的某個預(yù)測值,無法訓(xùn)練出單輸入多輸出的模型。
1.4 Gan模型難以訓(xùn)練的原因
GAN中最終達到對抗的納什均衡只是一個理想狀態(tài),而現(xiàn)實情況是,隨著訓(xùn)練次數(shù)的增多,判別器D的效果漸好,從而總是可以將生成器G的輸出與真實樣本區(qū)分開。
1.4.1 現(xiàn)象剖析
因為生成器G是從低維空間向高維空間(復(fù)雜的樣本空間)的映射,其生成的樣本分布空間Pg難以充滿整個真實樣本的分布空間Pr,即兩個分布完全沒有重疊的部分,或者重疊的部分可忽略,這就使得判別器D可以將其分開。
1.4.2 生成樣本與真實樣本重疊部分可忽略的原因
在二維平面中,隨機取兩條曲線,兩條曲線上的點可以代表二者的分布。要想讓判別器無法分辨它們,需要兩個分布融合在一起,也就是它們之間需要存在重疊的線段,然而這樣的概率為0。即使它們很可能會存在交叉點,但是相比于兩條曲線而言,交叉點比曲線低一個維度,也就是它只是一個點,代表不了分布情況,因此可將其忽略。
1.4.2 原因分析
假設(shè)先將判別器D訓(xùn)練得足夠好,固定判別器D后再來訓(xùn)練生成器G,通過實驗會發(fā)現(xiàn)G的loss值無法收斂到最小值,而是無限地接近一個特定值。這個值可以理解為模擬樣本分布Pg與原始樣本分布Pr兩個樣本分布之間的距離。對于loss值恒定(即表明生成器G的梯度為0)的情況,生成器G無法通過訓(xùn)練來優(yōu)化自己。
在原始GAN的訓(xùn)練中判別器訓(xùn)練得太好,生成器梯度就會逍失,生成器的lossS值降不下去;
在原始GAN的訓(xùn)練中判別器訓(xùn)練得不好,生成器梯度不準,抖動較大。
只有判別器訓(xùn)練到中間狀態(tài),才是最好的,但是這個尺度很難把握,甚至在同一輪訓(xùn)練的不同階段這個狀態(tài)出現(xiàn)的時段都不一樣,這是一個完全不可控的情況。
2 WGAN模型
WGAN的名字源于Wasserstein GAN,Vasserstein是指Wasserstein距離,又稱Earth-Mover(EM)推土機距離。
2.1 WGAN模型的原理
WGAN的原理是將生成的模擬樣本分布Pg與原始樣本分布Pr組合起來,并作為所有可能的聯(lián)合分布的集合,并計算出二者的距離和距離的期望值。
2.1.1 WGAN原理的優(yōu)勢
可以通過訓(xùn)練模型的方式,讓網(wǎng)絡(luò)沿著其該網(wǎng)絡(luò)所有可能的聯(lián)合分布期望值的下界方向進行優(yōu)化,即將兩個分布的集合拉到一起。此時,原來的判別器就不再具有判別真?zhèn)蔚墓δ?#xff0c;而獲得計算兩個分布集合距離的功能。因此,將其稱為評論器會更加合適。最后一層的Sigmoid函數(shù)也需要去掉(不需要將值域控制在0~1)。
2.2 WGAN模型的實現(xiàn)
使用神經(jīng)網(wǎng)絡(luò)來計算Wasserstein距離,可以讓神經(jīng)網(wǎng)絡(luò)直接擬合下式:
f(x)可以理解成神經(jīng)網(wǎng)絡(luò)的計算,讓判別器實現(xiàn)將f(x1)與f(x2)的距離變換成x1-x2的絕對值乘以k(k≥0)。k代表函數(shù)f(x)的Lipschitz常數(shù),這樣兩個分布集合的距離就可以表示成D(real)-D(G(x))的絕對值乘以k了。這個k可以理解成梯度,即在神經(jīng)網(wǎng)絡(luò)f(x)中乘以的梯度絕對值小于k。
將上式中的k忽略,經(jīng)過整理后,可以得到二者分布的距離公式:
現(xiàn)在要做的就是將L當(dāng)成目標來計算loss值。
判別器D的任務(wù)是區(qū)分它們,因為希望二者距離變大,所以loss值需要取反得到:
通過判別器D的losss值也可以看出生成器G的生成質(zhì)量,即loss值越小,代表距離越近,生成的質(zhì)量越高。
生成器G用來將希望模擬樣本分布Pg越來越接近原始樣本分布Pr,所以需要訓(xùn)練讓距離L最小化。因為生成器G與第一項無關(guān),所以G的loss值口可簡化為:
2.4 WGAN的缺點
若原始WGAN的Lipschitz限制的施加方式不對,那么使用梯度截斷方式太過生硬。每當(dāng)更新完一次判別器的參數(shù)之后,就應(yīng)檢查判別器中所有參數(shù)的絕對值有沒有超過閾值,有的話就把這些參數(shù)截斷回[-0.01,0.01]范圍內(nèi)。
Lipschitz限制本意是當(dāng)輸入的樣本稍微變化后,判別器給出的分數(shù)不能產(chǎn)生太過劇烈的變化。通過在訓(xùn)練過程中保證判別器的所有參數(shù)有界,可保證判別器不能對兩個略微不同的樣本給出天差地別的分數(shù)值,從而間接實現(xiàn)了Lipschitz限制。
這種期望與判別器本身的目的相矛盾。判別器中希望loss值盡可能大,這樣才能拉大真假樣本間的區(qū)別,但是這種情況會導(dǎo)致在判別器中,通過loss值算出來的梯度會沿著loss值越來越大的方向變化,然而經(jīng)過梯度截斷后每一個網(wǎng)絡(luò)參數(shù)又被獨立地限制了取值范圃(如[-0.01,0.01])。這種結(jié)果會使得所有的參數(shù)要么取最大值(如0.01),要么取最小值(如-0.01)。判別器沒能充分利用自身的模型能力,經(jīng)過它回傳給生成器的梯度也會跟著變差。
如果判別器是一個多層網(wǎng)絡(luò),那么梯度截斷還會導(dǎo)致梯度消失或者梯度“爆炸”問題。截斷閥值設(shè)置得稍微低一點,那么每經(jīng)過一層網(wǎng)絡(luò),梯度就會變小一點,多層之后就會呈指數(shù)衰減趨勢。
反之截斷闊值設(shè)置得稍大,每經(jīng)過一層網(wǎng)絡(luò),梯度變大一點,則多層之后就會呈指數(shù)爆炸趨勢。在實際應(yīng)用中,很難做到設(shè)合適,讓生或器獲得恰到好處的回傳梯度。
2.3 WGAN模型總結(jié)
WGAN引入了Wasserstein距離,由于它相對KL散度與JS散度具有優(yōu)越的平滑特性,因此理論上可以解決梯度消失問題。再利用一個參數(shù)數(shù)值范圍受限的判別器神經(jīng)網(wǎng)絡(luò)實現(xiàn)將Wasserstein距離數(shù)學(xué)變換寫成可求解的形式的最大化,可近似得到Wasserstein距離。
在此近似最優(yōu)判別器下,優(yōu)化生成器使得Wasserstein距離縮小,這能有效拉近生成分布與真實分布。WGAN既解決了訓(xùn)練不穩(wěn)定的問題,又提供了一個可靠的訓(xùn)練進程指標,而且該指標確實與生成樣本的質(zhì)量高度相關(guān)。
在實際訓(xùn)練過程中,WGAN直接使用截斷(clipping)的方式來防止梯度過大或過小。但這個方式太過生硬,在實際應(yīng)用中仍會出現(xiàn)問題,所以后來產(chǎn)生了其升級版WGAN-gp。
3 WGAN-gp模型(更容易訓(xùn)練的GAN模型)
WGAN-gp又稱為具有梯度懲罰的WGAN,是WGAN的升級版,一般可以用來全面代替WGAN。
3.1 WGAN-gp介紹
WGAN-gp中的gp是梯度懲罰(gradient penalty)的意思,是替換weight clipping的一種方法。通過直接設(shè)置一個額外的梯度懲罰項來實現(xiàn)判別器的梯度不超過k。其表達公式為:
其中,MSE為平方差公式;X_inter為整個聯(lián)合分布空間中的x取樣,即梯度懲罰項gradent _penaltys為求整個聯(lián)合分布空間中x對應(yīng)D的梯度與k的平方差。
3.2?WGAN-gp的原理與實現(xiàn)
3.3 Tip
- 因為要對每個樣本獨立地施加梯度懲罰,所以在判別器的模型架構(gòu)中不能使用BN算法,因為它會引入同一個批次中不同樣本的相互依賴關(guān)系。
- 如果需要的話,那么可以選擇其他歸一化辦法,如Layer Normalization、Weight Normalization、Instance Normalization等,這些方法不會引入樣本之間的依賴。
4 條件GAN
條件GAN的作用是可以讓GAN的生成器模型按照指定的類別生成模擬樣本。
4.1 條件GAN的實現(xiàn)
條件GAN在GAN的生成器和判別器基礎(chǔ)上各進行了一處改動;在它們的輸入部分加入了一個標簽向量(one_hot類型)。
4.2?條件GAN的原理
GAN的原理與條件變分自編碼神經(jīng)網(wǎng)絡(luò)的原理一樣。這種做法可以理解為給GAN增加一個條件,讓網(wǎng)絡(luò)學(xué)習(xí)圖片分布時加入標簽因素,這樣可以按照標簽的數(shù)值來生成指定的圖片。
5 帶有散度的GAN——WGAN-div
WGAN-div模型在WGAN-gp的基礎(chǔ)上,從理論層面進行了二次深化。在WGAN-gp中,將判別器的梯度作為懲罰項加入判別器的loss值中。
在計算判別器梯度時,為了讓X_inter從整個聯(lián)合分布空間的x中取樣,在真假樣本之間采取隨機取樣的方式,保證采樣區(qū)間屬于真假樣本的過渡區(qū)域。然而,這種方案更像是一種經(jīng)驗方案,沒有更完備的理論支撐(使用個體采樣代替整體分布,而沒能從整體分布層面直接解決問題)。
3.1?WGAN-div模型的使用思路
WGAN-div模型與WGAN-gp相比,有截然不同的使用思路:不從梯度懲罰的角度去考慮,而通過兩個樣本間的分布距離來實現(xiàn)。
在WGAN-diⅳ模型中,引入了W散度用于度量真假樣本分布之間的距離,并證明了中的W距離不是散度。這意味著WGAN-gp在訓(xùn)練判別器的時候,并非總會拉大兩個分布間的距離,從而在理論上證明了WGAN-gp存在的缺陷一—會有訓(xùn)練失效的情況。
WGAN-div模型從理論層面對WGAN進行了補充。利用WGAN-div模型的理論所實現(xiàn)的loss值不再需要采樣過程,并且所達到的訓(xùn)練效果也比WGAN-gp更勝一籌。
3.2 了解W散度
?3.3 WGAN-div的損失函數(shù)
3.4?W散度與W距離間的關(guān)系
總結(jié)
以上是生活随笔為你收集整理的【Pytorch神经网络理论篇】 23 对抗神经网络:概述流程 + WGAN模型 + WGAN-gp模型 + 条件GAN + WGAN-div + W散度的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: php l方法,ThinkPHP的L方法
- 下一篇: 【Pytorch神经网络理论篇】 16