简述生成式对抗网络 GAN
【轉(zhuǎn)載請(qǐng)注明出處】chenrudan.github.io
本文主要闡述了對(duì)生成式對(duì)抗網(wǎng)絡(luò)的理解,首先談到了什么是對(duì)抗樣本,以及它與對(duì)抗網(wǎng)絡(luò)的關(guān)系,然后解釋了對(duì)抗網(wǎng)絡(luò)的每個(gè)組成部分,再結(jié)合算法流程和代碼實(shí)現(xiàn)來解釋具體是如何實(shí)現(xiàn)并執(zhí)行這個(gè)算法的,最后給出一個(gè)基于對(duì)抗網(wǎng)絡(luò)改寫的去噪網(wǎng)絡(luò)運(yùn)行的結(jié)果,效果雖然挺差的,但是有些地方還是挺有意思的。
- 1. 對(duì)抗樣本
- 2. 生成式對(duì)抗網(wǎng)絡(luò)GAN
- 3. 代碼解釋
- 4. 運(yùn)行實(shí)例
- 5. 小結(jié)
- 6. 引用
1. 對(duì)抗樣本(adversarial examples)
14年的時(shí)候Szegedy在研究神經(jīng)網(wǎng)絡(luò)的性質(zhì)時(shí),發(fā)現(xiàn)針對(duì)一個(gè)已經(jīng)訓(xùn)練好的分類模型,將訓(xùn)練集中樣本做一些細(xì)微的改變會(huì)導(dǎo)致模型給出一個(gè)錯(cuò)誤的分類結(jié)果,這種雖然發(fā)生擾動(dòng)但是人眼可能識(shí)別不出來,并且會(huì)導(dǎo)致誤分類的樣本被稱為對(duì)抗樣本,他們利用這樣的樣本發(fā)明了對(duì)抗訓(xùn)練(adversarial training),模型既訓(xùn)練正常的樣本也訓(xùn)練這種自己造的對(duì)抗樣本,從而改進(jìn)模型的泛化能力[1]。如下圖所示,在未加擾動(dòng)之前,模型認(rèn)為輸入圖片有57.7%的概率為熊貓,但是加了之后,人眼看著好像沒有發(fā)生改變,但是模型卻認(rèn)為有99.3%的可能是長臂猿。
圖1 對(duì)抗樣本的產(chǎn)生(圖來源[2])這個(gè)問題乍一看很像過擬合,在Goodfellow在15年[3]提到了其實(shí)模型欠擬合也能導(dǎo)致對(duì)抗樣本,因?yàn)閺默F(xiàn)象上來說是輸入發(fā)生了一定程度的改變就導(dǎo)致了輸出的不正確,例如下圖一,上下分別是過擬合和欠擬合導(dǎo)致的對(duì)抗樣本,其中綠色的o和x代表訓(xùn)練集,紅色的o和x即對(duì)抗樣本,明顯可以看到欠擬合的情況下輸入發(fā)生改變也會(huì)導(dǎo)致分類不正確(其實(shí)這里我覺得有點(diǎn)奇怪,因?yàn)閳D中所描述的對(duì)抗樣本不一定就是跟原始樣本是同分布的,感覺是人為造的一個(gè)東西,而不是真實(shí)數(shù)據(jù)的反饋)。在[1]中作者覺得這種現(xiàn)象可能是因?yàn)樯窠?jīng)網(wǎng)絡(luò)的非線性和過擬合導(dǎo)致的,但Goodfellow卻給出了更為準(zhǔn)確的解釋,即對(duì)抗樣本誤分類是因?yàn)槟P偷木€性性質(zhì)導(dǎo)致的,說白了就是因?yàn)?span id="ozvdkddzhkzd" class="MathJax" id="MathJax-Element-1-Frame" tabindex="0" style="display:inline; line-height:normal; text-align:left; word-spacing:normal; word-wrap:normal; white-space:nowrap; float:none; direction:ltr; max-width:none; max-height:none; min-width:0px; min-height:0px; border:0px; padding:0px; margin:0px; position:relative">wTxwTx存在點(diǎn)乘,當(dāng)xx的每一個(gè)維度上都發(fā)生改變x?=x+ηx~=x+η,就會(huì)累加起來在點(diǎn)乘的結(jié)果上附加上一個(gè)比較大的和wTx?=wTx+wTηwTx~=wTx+wTη,而這個(gè)值可能就改變了預(yù)測(cè)結(jié)果。例如[4]中給出的一個(gè)例子,假設(shè)現(xiàn)在用邏輯回歸做二分類,輸入向量是x=[2,?1,3,?2,2,2,1,?4,5,1]x=[2,?1,3,?2,2,2,1,?4,5,1],權(quán)重向量是w=[?1,?1,1,?1,1,?1,1,1,?1,1]w=[?1,?1,1,?1,1,?1,1,1,?1,1],點(diǎn)乘結(jié)果是-3,類預(yù)測(cè)為1的概率為0.0474,假如將輸入變?yōu)?span id="ozvdkddzhkzd" class="MathJax" id="MathJax-Element-7-Frame" tabindex="0" style="display:inline; line-height:normal; text-align:left; word-spacing:normal; word-wrap:normal; white-space:nowrap; float:none; direction:ltr; max-width:none; max-height:none; min-width:0px; min-height:0px; border:0px; padding:0px; margin:0px; position:relative">xad=x+0.5w=[1.5,?1.5,3.5,?2.5,2.5,1.5,1.5,?3.5,4.5,1.5]xad=x+0.5w=[1.5,?1.5,3.5,?2.5,2.5,1.5,1.5,?3.5,4.5,1.5],那么類預(yù)測(cè)為1的概率就變成了0.88,就因?yàn)檩斎朐诿總€(gè)維度上的改變,導(dǎo)致了前后的結(jié)果不一致。
圖2 過/欠擬合導(dǎo)致對(duì)抗樣本(圖來源[3])如果認(rèn)為對(duì)抗樣本是因?yàn)槟P偷木€性性質(zhì)導(dǎo)致的,那么是否能夠構(gòu)造出一個(gè)方法來生成對(duì)抗樣本,即如何在輸入上加擾動(dòng),Goodfellow給出了一種構(gòu)造方法fast gradient sign method[2],其中JJ是損失函數(shù),再對(duì)輸入xx求導(dǎo),θθ是模型參數(shù),??是一個(gè)非常小的實(shí)數(shù)。圖1中就是?=0.007?=0.007。
η=?sign(▽xJ(θ,x,y))(1)η=?sign(▽xJ(θ,x,y))(1)這個(gè)構(gòu)造方法在[4]中有比較多的實(shí)例,這里截取了兩個(gè)例子來說明,用imagenet圖片縮放到64*64來訓(xùn)練一個(gè)一層的感知機(jī),輸入是64*64*3,輸出是1000,權(quán)重是64*64*3*1000,訓(xùn)練好之后取權(quán)重矩陣對(duì)應(yīng)某個(gè)輸出類別的一行64*64*3,將這行還原成64*64圖片顯示為下圖中第二列,再用公式1的方法從第一列的原始圖片中算出第三列的對(duì)抗樣本,可以看到第一行從預(yù)測(cè)為狐貍變成了預(yù)測(cè)為金魚,第二行變成了預(yù)測(cè)為校車。
圖3 構(gòu)造對(duì)抗樣本(圖來源[4])實(shí)際上不是只有純線性模型才會(huì)出現(xiàn)這種情況,卷積網(wǎng)絡(luò)的卷積其實(shí)就是線性操作,因此也有預(yù)測(cè)不穩(wěn)定的情況,relu/maxout甚至sigmoid的中間部分其實(shí)也算是線性操作。因?yàn)榭梢宰约簶?gòu)造對(duì)抗樣本,那么就能應(yīng)用這個(gè)性質(zhì)來訓(xùn)練模型,讓模型泛化能力更強(qiáng)。因而[2]給定了一種新的目標(biāo)函數(shù)也就是下面的式子,相當(dāng)于對(duì)輸入加入一些干擾,并且也通過實(shí)驗(yàn)結(jié)果證實(shí)了訓(xùn)練出來的模型更加能夠抵抗對(duì)抗樣本的影響。
J?(θ,x,y)=αJ(θ,x,y)+(1?α)J(θ,x+?sign(▽xJ(θ,x,y)))(2)J~(θ,x,y)=αJ(θ,x,y)+(1?α)J(θ,x+?sign(▽xJ(θ,x,y)))(2)對(duì)抗樣本跟生成式對(duì)抗網(wǎng)絡(luò)沒有直接的關(guān)系,對(duì)抗網(wǎng)絡(luò)是想學(xué)樣本的內(nèi)在表達(dá)從而能夠生成新的樣本,但是有對(duì)抗樣本的存在在一定程度上說明了模型并沒有學(xué)習(xí)到數(shù)據(jù)的一些內(nèi)部表達(dá)或者分布,而可能是學(xué)習(xí)到一些特定的模式足夠完成分類或者回歸的目標(biāo)而已。公式1的構(gòu)造方法只是在梯度方向上做了一點(diǎn)非常小的變化,但是模型就無法正確的分類。此外還觀察到一個(gè)現(xiàn)象,用不同結(jié)構(gòu)的多個(gè)分類器來學(xué)習(xí)相同數(shù)據(jù),往往會(huì)將相同的對(duì)抗樣本誤分到相同的類中,這個(gè)現(xiàn)象看上去是所有的分類器都被相同的變化所干擾了。
2. 生成式對(duì)抗網(wǎng)絡(luò)GAN
14年Goodfellow提出Generative adversarial nets即生成式對(duì)抗網(wǎng)絡(luò)[5],它要解決的問題是如何從訓(xùn)練樣本中學(xué)習(xí)出新樣本,訓(xùn)練樣本是圖片就生成新圖片,訓(xùn)練樣本是文章就輸出新文章等等。如果能夠知道訓(xùn)練樣本的分布p(x)p(x),那么就可以在分布中隨機(jī)采樣得到新樣本,大部分的生成式模型都采用這種思路,GAN則是在學(xué)習(xí)從隨機(jī)變量zz到訓(xùn)練樣本xx的映射關(guān)系,其中隨機(jī)變量可以選擇服從正太分布,那么就能得到一個(gè)由多層感知機(jī)組成的生成網(wǎng)絡(luò)G(z;θg)G(z;θg),網(wǎng)絡(luò)的輸入是一個(gè)一維的隨機(jī)變量,輸出是一張圖片。如何讓輸出的偽造圖片看起來像訓(xùn)練樣本,Goodfellow采用了這樣一種方法,在生成網(wǎng)絡(luò)后面接上一個(gè)多層感知機(jī)組成的判別網(wǎng)絡(luò)D(x;θd)D(x;θd),這個(gè)網(wǎng)絡(luò)的輸入是隨機(jī)選擇一張真實(shí)樣本或者生成網(wǎng)絡(luò)的輸出,輸出是輸入圖片來自于真實(shí)樣本pdatapdata或者生成網(wǎng)絡(luò)pgpg的概率,當(dāng)判別網(wǎng)絡(luò)能夠很好的分辨出輸入是不是真實(shí)樣本時(shí),也能通過梯度的方式說明什么樣的輸入更加像真實(shí)樣本,從而通過這個(gè)信息來調(diào)整生成網(wǎng)絡(luò)。從而GG需要盡可能的讓自己的輸出像真實(shí)樣本,而DD則盡可能的將不是真實(shí)樣本的情況分辨出來。下圖左邊是GAN算法的概率解釋,右邊是模型構(gòu)成。
圖4 GAN算法框圖(圖來源[6])GAN的優(yōu)化是一個(gè)極小極大博弈問題,最終的目的是generator的輸出給discriminator時(shí)很難判斷是真實(shí)or偽造的,即極大化DD的判斷能力,極小化將GG的輸出判斷為偽造的概率,公式如下。論文[5]中將下面式子轉(zhuǎn)化成了Jensen-shannon散度的形式證明了僅當(dāng)pg=pdatapg=pdata時(shí)能得到全局最小值,即生成網(wǎng)絡(luò)能完全的還原出真實(shí)樣本分布,并且證明了下式能夠收斂。(算法流程論文講的很清楚,這里就不說了,后面結(jié)合代碼一起解釋。)
minGmaxDV(D,G)=Ex~pdata(x)[logD(x)]+Ez~pz(z)[log(1?D(G(z)))](3)minGmaxDV(D,G)=Ex~pdata(x)[logD(x)]+Ez~pz(z)[log(1?D(G(z)))](3)以上是關(guān)于最基本GAN的介紹,最開始我看了論文后產(chǎn)生了幾個(gè)疑問,1.為什么不能直接學(xué)習(xí)GG,即直接學(xué)習(xí)一個(gè)zz到一個(gè)xx?2.GG具體是如何訓(xùn)練的?3.在訓(xùn)練的時(shí)候zz跟xx是一一對(duì)應(yīng)關(guān)系嗎?在對(duì)代碼理解之后大概能夠給出一個(gè)解釋。
3. 代碼解釋
這部分主要結(jié)合tensorflow實(shí)現(xiàn)代碼[7]、算法流程和下面的變化圖[5]解釋一下具體如何使用DCGAN來生成手寫體圖片。
下圖中黑色虛線是真實(shí)數(shù)據(jù)的高斯分布,綠色的線是生成網(wǎng)絡(luò)學(xué)習(xí)到的偽造分布,藍(lán)色的線是判別網(wǎng)絡(luò)判定為真實(shí)圖片的概率,標(biāo)x的橫線代表服從高斯分布x的采樣空間,標(biāo)z的橫線代表服從均勻分布z的采樣空間。可以看出GG就是學(xué)習(xí)了從z的空間到x的空間的映射關(guān)系。
圖5 GAN運(yùn)行時(shí)各個(gè)概率分布圖(圖來源[5])a.起始情況
DD是一個(gè)卷積神經(jīng)網(wǎng)絡(luò),變量名是D,其中一層構(gòu)造方式如下。
| 12345678 | w = tf.get_variable('w', [4, 4, c_dim, num_filter], initializer=tf.truncated_normal_initializer(stddev=stddev))dconv = tf.nn.conv2d(ddata, w, strides=[1, 2, 2, 1], padding='SAME')biases = tf.get_variable('biases', [num_filter], initializer=tf.constant_initializer(0.0))bias = tf.nn.bias_add(dconv, biases)dconv1 = tf.maximum(bias, leak*bias)... |
GG是一個(gè)逆卷積神經(jīng)網(wǎng)絡(luò),變量名是G,其中一層構(gòu)造方式如下。
| 12345678910 | w = tf.get_variable('w', [4, 4, num_filter, num_filter*2], initializer=tf.random_normal_initializer(stddev=stddev))deconv = tf.nn.conv2d_transpose(gconv2, w, output_shape=[batch_size, s2, s2, num_filter], strides=[1, 2, 2, 1])biases = tf.get_variable('biases', [num_filter],initializer=tf.constant_initializer(0.0))bias = tf.nn.bias_add(deconv, biases)deconv1 = tf.nn.relu(bias, name=scope.name)... |
GG的網(wǎng)絡(luò)輸入為一個(gè)zdimzdim維服從-1~1均勻分布的隨機(jī)變量,這里取的是100.
| 12 | batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]).astype(np.float32) |
DD的網(wǎng)絡(luò)輸入是一個(gè)batch的64*64的圖片,既可以是手寫體數(shù)據(jù)也可以是GG的一個(gè)batch的輸出。
這個(gè)過程可以參考上圖的a狀態(tài),判別曲線處于不夠穩(wěn)定的狀態(tài),兩個(gè)網(wǎng)絡(luò)都還沒訓(xùn)練好。
b.訓(xùn)練判別網(wǎng)絡(luò)
判別網(wǎng)絡(luò)的損失函數(shù)由兩部分組成,一部分是真實(shí)數(shù)據(jù)判別為1的損失,一部分是GG的輸出self.G判別為0的損失,需要優(yōu)化的損失函數(shù)定義如下。
| 123456789 | self.G = self.generator(self.z)self.D, self.D_logits = self.discriminator(self.images)self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))self.d_loss = self.d_loss_real + self.d_loss_fake |
然后將一個(gè)batch的真實(shí)數(shù)據(jù)batch_images,和隨機(jī)變量batch_z當(dāng)做輸入,執(zhí)行session更新DD的參數(shù)。
| 123456 | # update discriminator on reald_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, beta1=FLAGS.beta1).minimize(d_loss, var_list=d_vars)...out1 = sess.run([d_optim], feed_dict={real_images: batch_images, noise_images: batch_z}) |
這一步可以對(duì)比圖b,判別曲線漸漸趨于平穩(wěn)。
c.訓(xùn)練生成網(wǎng)絡(luò)
生成網(wǎng)絡(luò)并沒有一個(gè)獨(dú)立的目標(biāo)函數(shù),它更新網(wǎng)絡(luò)的梯度來源是判別網(wǎng)絡(luò)對(duì)偽造圖片求的梯度,并且是在設(shè)定偽造圖片的label是1的情況下,保持判別網(wǎng)絡(luò)不變,那么判別網(wǎng)絡(luò)對(duì)偽造圖片的梯度就是向著真實(shí)圖片變化的方向。
| 12 | self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_))) |
然后用同樣的隨機(jī)變量batch_z當(dāng)做輸入更新
| 1234 | g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) .minimize(self.g_loss, var_list=self.g_vars)...out2 = sess.run([g_optim], feed_dict={noise_images:batch_z}) |
這一步可以對(duì)比圖c,pgpg的曲線在漸漸的向真實(shí)分布靠攏。而網(wǎng)絡(luò)訓(xùn)練完成之后可以看到pgpg的曲線與pdatapdata重疊在了一起,并且此時(shí)判別網(wǎng)絡(luò)已經(jīng)難以區(qū)分真實(shí)與偽造,因此取值就固定在了1212。
因而針對(duì)我之前的問題,2已經(jīng)有了答案,針對(duì)1,為什么不能直接學(xué)習(xí)GG?這是因?yàn)闊o法確定zz與xx的一一對(duì)應(yīng)關(guān)系,就像下圖,兩種對(duì)應(yīng)關(guān)系,如果要肯定誰是對(duì)誰是錯(cuò),那么就得加入一些先驗(yàn)信息,甚至是直接對(duì)真實(shí)樣本的估計(jì),那么跟其他的方法不就一樣了么。而問題3,在訓(xùn)練的時(shí)候zz跟xx是一一對(duì)應(yīng)關(guān)系嗎?我開始考慮這個(gè)問題是因?yàn)椴磺宄遣皇且粋€(gè)100維的noise變量就對(duì)應(yīng)著一個(gè)手寫體變量圖片,但是現(xiàn)在考慮一下就應(yīng)該明白在訓(xùn)練的層面上不是一一對(duì)應(yīng)的,甚至兩者在訓(xùn)練DD的時(shí)候都是分開的,只是可能在分布中會(huì)存在這樣一種對(duì)應(yīng)關(guān)系而已。
圖6 z與x映射圖(圖來源[8])4. 運(yùn)行實(shí)例
這里本來想用GAN來跑一個(gè)去噪的網(wǎng)絡(luò),基于[7]的代碼改了一下輸入,從一個(gè)100維的noise向量變成了一張輸入圖片,同時(shí)將generator網(wǎng)絡(luò)的前面部分變成了卷積網(wǎng)絡(luò),再連上原來的逆卷積,就成了一個(gè)去噪網(wǎng)絡(luò),這里我沒太多時(shí)間來細(xì)致的調(diào)節(jié)網(wǎng)絡(luò)層數(shù)、參數(shù)等,就隨便試了一下,效果也不是特別的好。代碼在[9]中。首先我通過read_stl10.py對(duì)stl10數(shù)據(jù)集加上了均值為0方差為50的高斯噪聲,前后對(duì)比如下。
圖7 增加高斯噪聲前后對(duì)比然后執(zhí)行對(duì)抗網(wǎng)絡(luò),會(huì)得到如下的去噪效果,從左到右分別是加了噪聲的輸入圖片,對(duì)應(yīng)的generator網(wǎng)絡(luò)的輸出圖片,已經(jīng)對(duì)應(yīng)的干凈圖片,效果不是特別好,輪廓倒是能學(xué)到一點(diǎn),但是這個(gè)顏色卻沒學(xué)到。
圖8 去噪對(duì)比5. 小結(jié)
剛開始搜資料的時(shí)候發(fā)現(xiàn)了對(duì)抗樣本,以為跟對(duì)抗網(wǎng)絡(luò)有關(guān)系,就看了一下,后來看Goodfellow的論文時(shí)發(fā)現(xiàn)其實(shí)沒什么關(guān)系,但是還是寫了一些內(nèi)容,因?yàn)檫@個(gè)東西的存在還是值得了解的,而對(duì)抗網(wǎng)絡(luò)這個(gè)想法真的太贊了,它將一個(gè)無監(jiān)督問題轉(zhuǎn)化為有監(jiān)督,更加像一種learn的方式來學(xué)習(xí)數(shù)據(jù)應(yīng)該是如何產(chǎn)生,而不是find的方式來找某些特征,但是訓(xùn)練也是一個(gè)難題,從我的經(jīng)驗(yàn)來看,特別容易過擬合,而且確實(shí)有一種對(duì)抗的感覺在里面,因?yàn)間enerator的輸入時(shí)好時(shí)壞,總的來說是個(gè)很棒的算法,非常期待接下來的研究。
6. 引用
[1]?Intriguing properties of neural networks
[2]?EXPLAINING AND HARNESSING ADVERSARIAL EXAMPLES
[3]?Adversarial Examples
[4]?Breaking Linear Classifiers on ImageNet
[5]?Generative Adversarial Nets
[6]?Quick introduction to GANs
[7]?carpedm20/DCGAN-tensorflow
[8]?Generative Adversarial Nets in TensorFlow (Part I)
[9]?chenrudan/deep-learning/denoise_dcgan/
總結(jié)
以上是生活随笔為你收集整理的简述生成式对抗网络 GAN的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Part 2 – Deep analys
- 下一篇: 【David Silver强化学习公开课