TensorFlow2-生成对抗网络
What i can not create, i do not understand. 我不能創造的東西,我當然不能理解它。
簡介
對抗生成網絡(GAN)是時下非常熱門的一種神經網絡,它主要用于復現數據的分布(distribution,或者數據的表示(representation))。盡管數據的分布非常的復雜,但是依靠神經網絡強大的學習能力,可以學習其中的表示。其中,最典型的技術就是圖像生成。GAN的出現是神經網絡技術發展極具突破的一個創新。從2014年GAN誕生之時只能和VAE旗鼓相當,到2018年WGAN的以假亂真,GAN的發展是迅速的。
原理
GAN網絡由兩個部分組成,它們是生成器(Generator)和判別器(Discriminator)。將輸入數據與生成器產生的數據同時交給判別器檢驗,如果兩者的分布接近(p_g接近p_r),則表示生成器逐漸學習數據的分布,當接近到一定程度(判別器無法判別生成數據的真假),認為學習成功。
因此關于生成器G和判別器D之間的優化目標函數如下,這就是GAN網絡訓練的目標。
min?Gmax?DL(D,G)=Ex~pr(x)[log?D(x)]+Ez~pz(z)[log?(1?D(G(z)))]=Ex~pr(x)[log?D(x)]+Ex~pz(x)[log?(1?D(x)]\begin{aligned} \min _{G} \max _{D} L(D, G) &=\mathbb{E}_{x \sim p_{r}(x)}[\log D(x)]+\mathbb{E}_{z \sim p_{z}(z)}[\log (1-D(G(z)))] \\ &=\mathbb{E}_{x \sim p_{r}(x)}[\log D(x)]+\mathbb{E}_{x \sim p_{z}(x)}[\log (1-D(x)]\end{aligned} Gmin?Dmax?L(D,G)?=Ex~pr?(x)?[logD(x)]+Ez~pz?(z)?[log(1?D(G(z)))]=Ex~pr?(x)?[logD(x)]+Ex~pz?(x)?[log(1?D(x)]?
衡量兩種分布之間的距離,GAN使用JS散度(基于KL散度推導)衡量兩種分布的差異,然而當兩種分布(生成器分布和真實分布)直接沒有交叉時,KL散度總是0,JS散度總是log2,這就導致JS散度無法很好量化兩種分布的差異。同時,此時的將會出現梯度彌散,這也是很多GAN網絡難以訓練的原因。
因此,有人提出了衡量兩種分布P和Q之間差異的方式是從P分布到Q分布需要經歷的變化(代價),可以理解為下圖的一種分布變為另一種分布需要移動的磚塊數目(移土距離,Earth Mover’s Distance, EM距離)。
B(γ)=∑xp,xqγ(xp,xq)∥xp?xq∥B(\gamma)=\sum_{x_{p}, x_{q}} \gamma\left(x_{p}, x_{q}\right)\left\|x_{p}-x_{q}\right\| B(γ)=xp?,xq?∑?γ(xp?,xq?)∥xp??xq?∥
W(P,Q))=min?γ∈ΠB(γ)W(P, Q))=\min _{\gamma \in \Pi} B(\gamma) W(P,Q))=γ∈Πmin?B(γ)
基于此提出了Wasserstein Distance距離如下,將網絡中的JS散度替換為Wasserstein Distance的GAN,稱為WGAN,它可以從根本上結局不重疊的分布距離難以衡量的問題從而避免訓練早期的梯度彌散。(必須滿足1-Lipschitz function,為了滿足這個條件要進行weight clipping,但是即使weight clipping也不一定可以滿足1-Lipschitz function條件。)
W(Pr,Pg)=inf?γ∈Π(Pr,Pg)E(x,y)~γ[∥x?y∥]W\left(\mathbb{P}_{r}, \mathbb{P}_{g}\right)=\inf _{\gamma \in \Pi\left(\mathbb{P}_{r}, \mathbb{P}_{g}\right)} \mathbb{E}_{(x, y) \sim \gamma}[\|x-y\|] W(Pr?,Pg?)=γ∈Π(Pr?,Pg?)inf?E(x,y)~γ?[∥x?y∥]
因此,為了滿足這個條件提出了WGAN-GP(Gradient Penalty),將這個條件寫入損失函數,要求必須在1附近。
GAN發展
從GAN思路被提出以來,產生了各種各樣的GAN,每一種GAN都有自己的名字,一般以首字母簡略稱呼(如今A-Z已經幾乎用完,可見這幾年GAN的發展迅速)。
其中,比較著名的有DCGAN(反卷積GAN,用于圖片擴張)。
此外,還有LSGAN、WGAN(盡管效果不如DCGAN,但是不需要花太多精力設計訓練過程)等。
GAN實戰
基于日本Anime數據集生成相應的二次元人物頭像,數據集的百度網盤地址如下,提取碼g5qa。
構建的GAN模型結構示意如下,判別器是一個基礎的CNN分類器,生成器是將隨機生成的數據進行升維成圖。
下面給出模型結構代碼,具體的訓練代碼可以在文末Github找到。
WGAN只需要在GAN代碼基礎上添加懲罰項,具體見Github。
補充說明
- 本文介紹了GAN在TensorFlow2中的實現,更詳細的可以查看官方文檔。
- 具體的代碼同步至我的Github倉庫歡迎star;博客同步至我的個人博客網站,歡迎查看其他文章。
- 如有疏漏,歡迎指正。
總結
以上是生活随笔為你收集整理的TensorFlow2-生成对抗网络的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: TensorFlow2-自编码器
- 下一篇: Mathpix教程