深度学习之生成对抗网络(7)WGAN原理
深度學(xué)習(xí)之生成對(duì)抗網(wǎng)絡(luò)(7)WGAN原理
- 1. JS散度的缺陷
- 2. EM距離
- 3. WGAN-GP
?WGAN算法從理論層面分析了GAN訓(xùn)練不穩(wěn)定的原因,并提出了有效的解決方法。那么是什么原因?qū)е铝薌AN訓(xùn)練如此不穩(wěn)定呢?WGAN提出是因?yàn)镴S散度在不重疊的分布 ppp和 qqq上的梯度曲面是恒定為0的。如下圖所示。當(dāng)分布p和q不重疊時(shí),JS散度的梯度值始終為0,從而導(dǎo)致此時(shí)GAN的訓(xùn)練出現(xiàn)梯度彌散現(xiàn)象,參數(shù)長(zhǎng)時(shí)間得不到更新,網(wǎng)絡(luò)無(wú)法收斂。
圖1. JS散度出現(xiàn)梯度彌散現(xiàn)象
?接下來(lái)我們將詳細(xì)闡述JS散度的缺陷以及怎么解決此缺陷。
1. JS散度的缺陷
為了避免過(guò)多的理論推導(dǎo),我們這里通過(guò)一個(gè)簡(jiǎn)單的分布實(shí)例來(lái)解釋JS散度的缺陷。
考慮完全不重疊(θ≠0θ≠0θ?=0)的兩個(gè)分布ppp和qqq,其中ppp為:
?(x,y)∈p,x=0,y~U(0,1)?(x,y)∈p,x=0,y\sim\text{U}(0,1)?(x,y)∈p,x=0,y~U(0,1)
分布qqq為:
?(x,y)∈q,x=θ,y~U(0,1)?(x,y)∈q,x=θ,y\sim\text{U}(0,1)?(x,y)∈q,x=θ,y~U(0,1)
其中θ∈Rθ∈Rθ∈R,當(dāng)θ=0θ=0θ=0時(shí),分布ppp和qqq重疊,兩者相等;當(dāng)θ≠0θ≠0θ?=0時(shí),分布ppp和qqq不重疊。
?我們來(lái)分析上述分布ppp和qqq之間的JS散度隨θθθ的變化情況。根據(jù)KL散度與JS散度的定義,計(jì)算θ=0θ=0θ=0時(shí)的JS散度DJS(p∣∣q)D_{JS} (p||q)DJS?(p∣∣q):
DKL(p∣∣q)=∑x=0,y~U(0,1)1?log?10=+∞D(zhuǎn)_{KL} (p||q)=∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}?\frac{1}{0}=+∞DKL?(p∣∣q)=x=0,y~U(0,1)∑?1?log?01?=+∞
DKL(q∣∣p)=∑x=θ,y~U(0,1)1?log?10=+∞D(zhuǎn)_{KL} (q||p)=∑_{x=θ,y\sim\text{U}(0,1)}1\cdot\text{log}?\frac{1}{0}=+∞DKL?(q∣∣p)=x=θ,y~U(0,1)∑?1?log?01?=+∞
DJS(p∣∣q)=12(∑x=0,y~U(0,1)1?log11/2+∑x=0,y~U(0,1)1?log11/2)=log?2D_{JS} (p||q)=\frac{1}{2} \bigg(∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{1/2}+∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{1/2}\bigg)=\text{log}?2DJS?(p∣∣q)=21?(x=0,y~U(0,1)∑?1?log1/21?+x=0,y~U(0,1)∑?1?log1/21?)=log?2
?當(dāng)θ=0θ=0θ=0時(shí),兩個(gè)分布完全重疊,此時(shí)的JS散度和KL散度都取得最小值,即0:
DKL(p∣∣q)=DKL(q∣∣p)=DJS(p∣∣q)=0D_{KL} (p||q)=D_{KL} (q||p)=D_{JS} (p||q)=0DKL?(p∣∣q)=DKL?(q∣∣p)=DJS?(p∣∣q)=0
從上面的推導(dǎo),我們可以得到DJS(p∣∣q)D_{JS} (p||q)DJS?(p∣∣q)隨θθθ的變化趨勢(shì):
DJS(p∣∣q)={log?2θ≠00θ=0D_{JS} (p||q) = \begin{cases} \text{log?}2 &\text{} θ≠0 \\ 0 &\text{} θ=0 \end{cases}DJS?(p∣∣q)={log?20?θ?=0θ=0?
也就是說(shuō),當(dāng)兩個(gè)分布完全不重疊時(shí),無(wú)論發(fā)布之間的距離遠(yuǎn)近,JS散度為恒定值log?2\text{log}?2log?2,此時(shí)JS散度將無(wú)法產(chǎn)生有效的梯度信息;當(dāng)兩個(gè)分布出現(xiàn)重疊時(shí),JS散度采會(huì)平滑變動(dòng),產(chǎn)生有效梯度信息;當(dāng)完全重合后,JS散度取得最小值0.如下圖所示,紅色的曲線分割兩個(gè)正態(tài)分布,由于兩個(gè)分布沒(méi)有重疊,生成樣本位置處的梯度值始終為0,無(wú)法更新生成網(wǎng)絡(luò)的參數(shù),從而出現(xiàn)網(wǎng)絡(luò)訓(xùn)練困難的現(xiàn)象。
?因此,JS散度在分布ppp和qqq不重疊時(shí)是無(wú)法平滑地衡量分布之間的距離,從而導(dǎo)致此位置上無(wú)法產(chǎn)生有效梯度信息,出現(xiàn)GAN訓(xùn)練不穩(wěn)定的情況。要解決此問(wèn)題,需要使用一種更好的分布距離衡量標(biāo)準(zhǔn),使得它即使在分布ppp和qqq不重疊時(shí),也能平滑反映分布之間的真實(shí)距離變化。
2. EM距離
?WGAN論文發(fā)現(xiàn)了JS散度導(dǎo)致GAN訓(xùn)練不穩(wěn)定的問(wèn)題,并引入了一種新的分布距離度量方法:Wasserstein距離,也叫推土機(jī)距離(Earth-Mover Distance,簡(jiǎn)稱EM距離),它表示了從一個(gè)分布變換到另一個(gè)分布的最小代價(jià),定義為:
W(p,q)=infγ~∏(p,q)E(x,y)~γ[∥x?y∥]W(p,q)=\underset{γ\sim∏(p,q)}{\text{inf}}\mathbb E_{(x,y)\simγ} [\|x-y\|]W(p,q)=γ~∏(p,q)inf?E(x,y)~γ?[∥x?y∥]
其中∏(p,q)∏(p,q)∏(p,q)是分布ppp和qqq組合起來(lái)的所有可能的聯(lián)合分布的集合,對(duì)于每個(gè)可能的聯(lián)合分布γ~∏(p,q)γ\sim∏(p,q)γ~∏(p,q),計(jì)算距離∥x?y∥\|x-y\|∥x?y∥的期望E(x,y)~γ[∥x?y∥]\mathbb E_{(x,y)\simγ} [\|x-y\|]E(x,y)~γ?[∥x?y∥],其中(x,y)(x,y)(x,y)采樣自聯(lián)合分布γγγ。不同的聯(lián)合分布γγγ由不同的期望E(x,y)~γ[∥x?y∥]\mathbb E_{(x,y)\simγ} [\|x-y\|]E(x,y)~γ?[∥x?y∥],這些期望中的下確界即定義為分布ppp和qqq的Wasserstein距離。其中inf?{?}\text{inf}?\{\cdot\}inf?{?}表示集合的下確界,例如{x∣1<x<3,x∈R}\{x|1<x<3,x∈R\}{x∣1<x<3,x∈R}的下確界為1。
?繼續(xù)考慮圖2中的例子,我們直接給出分布ppp和qqq之間的EM距離的表達(dá)式:
W(p,q)=∣θ∣W(p,q)=|θ|W(p,q)=∣θ∣
繪制出JS散度和EM距離的曲線,如下圖所示,可以看到,JS散度在θ=0θ=0θ=0處不連續(xù),其他位置導(dǎo)數(shù)均為0,而EM距離總能夠產(chǎn)生有效的導(dǎo)數(shù)信息,因此EM距離相對(duì)于JS散度更適合直到GAN網(wǎng)絡(luò)的訓(xùn)練。
3. WGAN-GP
?考慮到幾乎不可能遍歷所有的聯(lián)合分布γγγ去計(jì)算距離∥x?y∥\|x-y\|∥x?y∥的期望E(x,y)~γ[∥x?y∥]\mathbb E_{(x,y)\simγ} [\|x-y\|]E(x,y)~γ?[∥x?y∥],因此直接計(jì)算生成網(wǎng)絡(luò)分布pgp_gpg?與真實(shí)數(shù)據(jù)數(shù)據(jù)分布prp_rpr?的距離W(pr,pg)W(p_r,p_g )W(pr?,pg?)距離是不現(xiàn)實(shí)的,WGAN作者基于Kantorchovich-Rubin對(duì)偶性將直接求W(pr,pg)W(p_r,p_g )W(pr?,pg?)轉(zhuǎn)換為求:
W(pr,pg)=1Ksup∥f∥L≤KEx~pr[f(x)]?Ex~pg[f(x)]W(p_r,p_g )=\frac{1}{K} \underset{\|f\|_L≤K}{\text{sup}} \mathbb E_{x\sim p_r} [f(x)]-\mathbb E_{x\sim p_g} [f(x)]W(pr?,pg?)=K1?∥f∥L?≤Ksup?Ex~pr??[f(x)]?Ex~pg??[f(x)]
其中sup?{?}\text{sup}?\{\cdot\}sup?{?}表示集合的上確界,∥f∥L≤K\|f\|_L≤K∥f∥L?≤K表示函數(shù)f:R→Rf:R→Rf:R→R滿足K階-Lipschitz連續(xù)性,即滿足
∣f(x1)?f(x2)∣≤K?∣x1?x2∣|f(x_1 )-f(x_2)|≤K\cdot|x_1-x_2 |∣f(x1?)?f(x2?)∣≤K?∣x1??x2?∣
?于是,我們使用判別網(wǎng)絡(luò)Dθ(x)D_θ (\boldsymbol x)Dθ?(x)參數(shù)化f(x)f(\boldsymbol x)f(x)函數(shù),在DθD_θDθ?滿足1階-Lipschitz約束條件下,即K=1K=1K=1,此時(shí):
W(pr,pg)=1Ksup∥Dθ∥L≤KEx~pr[Dθ(x)]?Ex~pg[Dθ(x)]W(p_r,p_g )=\frac{1}{K} \underset{\|D_θ\|_L≤K}{\text{sup}} \mathbb E_{x\sim p_r} [D_θ (\boldsymbol x)]-\mathbb E_{x\sim p_g} [D_θ (\boldsymbol x)]W(pr?,pg?)=K1?∥Dθ?∥L?≤Ksup?Ex~pr??[Dθ?(x)]?Ex~pg??[Dθ?(x)]
因此求解W(pr,pg)W(p_r,p_g )W(pr?,pg?)的問(wèn)題可以轉(zhuǎn)化為:
max?θEx~pr[Dθ(x)]?Ex~pg[Dθ(x)]\underset{θ}{\text{max}?}\ \mathbb E_{x\sim p_r} [D_θ (\boldsymbol x)]-\mathbb E_{x\sim p_g} [D_θ (\boldsymbol x)]θmax???Ex~pr??[Dθ?(x)]?Ex~pg??[Dθ?(x)]
這就是判別器D的優(yōu)化目標(biāo)。判別網(wǎng)絡(luò)函數(shù)D_θ (x)需要滿足1階-Lipschitz約束:
?x^D(x^)≤1?_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})≤1?x^?D(x^)≤1
?在WGAN-GP論文中,作者提出采用增加梯度懲罰項(xiàng)(Gradient Penalty)方法來(lái)迫使判別網(wǎng)絡(luò)滿足1階-Lipschitz函數(shù)約束,同時(shí)作者發(fā)現(xiàn)將梯度值約束在1周?chē)鷷r(shí)工程效果更好,因此梯度懲罰項(xiàng)定義為:
GP?Ex^~Px^[(∥?x^D(x^)∥2?1)2]GP?\mathbb E_{\hat{\boldsymbol x}\sim P_{\hat{\boldsymbol x}}} [(\|?_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2-1)^2]GP?Ex^~Px^??[(∥?x^?D(x^)∥2??1)2]
因此WGAN的判別器D的訓(xùn)練目標(biāo)為:
maxθL(G,D)=Exr~pr[D(xr)]?Exf~pg[D(xf)]?EM距離?λEx^~Px^[(∥?x^D(x^)∥2?1)2]?GP懲罰項(xiàng)\underset{θ}{\text{max}} \mathcal L(G,D)=\underbrace{\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]}_{EM距離}-\underbrace{λ\mathbb E_{\hat{\boldsymbol x}\sim P_{\hat{\boldsymbol x}}} [(\|?_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2-1)^2]}_{GP懲罰項(xiàng)}θmax?L(G,D)=EM距離Exr?~pr??[D(xr?)]?Exf?~pg??[D(xf?)]???GP懲罰項(xiàng)λEx^~Px^??[(∥?x^?D(x^)∥2??1)2]??
其中x^\hat{\boldsymbol x}x^來(lái)自于xr\boldsymbol x_rxr?與xf\boldsymbol x_fxf?的線性差值:
x^=txr+(1?t)xf,t∈[0,1]\hat{\boldsymbol x}=t\boldsymbol x_r+(1-t) \boldsymbol x_f,t∈[0,1]x^=txr?+(1?t)xf?,t∈[0,1]
判別器D的優(yōu)化目標(biāo)是最小化上述的誤差L(G,D)\mathcal L(G,D)L(G,D),即迫使生成器G的分布pgp_gpg?與真實(shí)分布prp_rpr?之間的EM距離Exr~pr[D(xr)]?Exf~pg[D(xf)]\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]Exr?~pr??[D(xr?)]?Exf?~pg??[D(xf?)]項(xiàng)盡可能大,∥?x^D(x^)∥2\|?_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2∥?x^?D(x^)∥2?逼近于1。
?WGAN的生成器G的訓(xùn)練目標(biāo)為:
maxθL(G,D)=Exr~pr[D(xr)]?Exf~pg[D(xf)]?EM距離\underset{θ}{\text{max}} \mathcal L(G,D)=\underbrace{\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]}_{EM距離}θmax?L(G,D)=EM距離Exr?~pr??[D(xr?)]?Exf?~pg??[D(xf?)]??
即使得生成器的分布pgp_gpg?與真實(shí)分布prp_rpr?之間的EM距離越小越好。考慮到Exr~pr[D(xr)]\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]Exr?~pr??[D(xr?)]一項(xiàng)與生成器無(wú)關(guān),因此生成器的訓(xùn)練目標(biāo)簡(jiǎn)寫(xiě)為:
maxθL(G,D)=?Exf~pg[D(xf)]=?Ez~pz(?)[D(G(z))]\begin{aligned}\underset{θ}{\text{max}} \mathcal L(G,D)&=-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]\\ &=-E_{\boldsymbol z\sim p_\boldsymbol z (\cdot)} [D(G(\boldsymbol z))]\end{aligned}θmax?L(G,D)?=?Exf?~pg??[D(xf?)]=?Ez~pz?(?)?[D(G(z))]?
?從現(xiàn)實(shí)來(lái)看,判別網(wǎng)絡(luò)D的輸出不需要添加Sigmoid激活函數(shù),這是因?yàn)樵及姹镜呐袆e器的功能是作為二分類網(wǎng)絡(luò),添加Sigmoid函數(shù)獲得類別的概率;而WGAN中判別器作為EM距離的度量網(wǎng)絡(luò),其目標(biāo)是衡量生成網(wǎng)絡(luò)的分布pgp_gpg?和真實(shí)分布prp_rpr?之間的EM距離,屬于實(shí)數(shù)空間,因此不需要添加Sigmoid激活函數(shù)。在誤差函數(shù)計(jì)算時(shí),WGAN也沒(méi)有log\text{log}log函數(shù)存在。在訓(xùn)練WGAN時(shí),WGAN作者推薦使用RMSProp或SGD等不帶動(dòng)量的優(yōu)化器。
?WGAN從理論層面發(fā)現(xiàn)了原始GAN容易出現(xiàn)訓(xùn)練不穩(wěn)定的原因,并給出了一種新的距離度量標(biāo)準(zhǔn)和工程實(shí)現(xiàn)解決方案,取得了較好的效果。WGAN還在一定程度上緩解了模式崩塌的問(wèn)題,使用WGAN的模型不容易出現(xiàn)模式崩塌的現(xiàn)象。需要注意的是,WGAN一般并不能提升模型的生成效果,僅僅是保證了模型訓(xùn)練的穩(wěn)定性。當(dāng)然,保證模型能夠穩(wěn)定地訓(xùn)練也是取得良好效果的前提。如圖5所示,原始版本的DCGAN在不使用BN層等設(shè)定時(shí)出現(xiàn)了訓(xùn)練不穩(wěn)定的現(xiàn)象,在同樣設(shè)定下,使用WGAN來(lái)訓(xùn)練判別器可以避免此現(xiàn)象,如圖6所示。
圖6. 不帶BN層的WGAN生成效果 創(chuàng)作挑戰(zhàn)賽新人創(chuàng)作獎(jiǎng)勵(lì)來(lái)咯,堅(jiān)持創(chuàng)作打卡瓜分現(xiàn)金大獎(jiǎng)
總結(jié)
以上是生活随笔為你收集整理的深度学习之生成对抗网络(7)WGAN原理的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 使用linux的dhclient命令动态
- 下一篇: 深度学习之生成对抗网络(8)WGAN-G