生成对抗网络入门详解及TensorFlow源码实现--深度学习笔记
生成對(duì)抗網(wǎng)絡(luò)入門詳解及TensorFlow源碼實(shí)現(xiàn)–深度學(xué)習(xí)筆記
一、生成對(duì)抗網(wǎng)絡(luò)(GANs)
生成對(duì)抗網(wǎng)絡(luò)是一種生成模型(Generative Model),其背后最基本的思想就是從訓(xùn)練庫(kù)里獲取很多的訓(xùn)練樣本(Training Examples),從而學(xué)習(xí)這些訓(xùn)練案例生成的概率分布。
GAN[Goodfellow Ian,GAN]啟發(fā)自博弈論中的二人零和博弈(two-player game),由[Goodfellow et al, NIPS 2014]開創(chuàng)性地提出。在二人零和博弈中,兩位博弈方的利益之和為零或一個(gè)常數(shù),即一方有所得,另一方必有所失。GAN模型中的兩位博弈方分別由生成式模型(generative model)和判別式模型(discriminative model)充當(dāng)。生成模型G捕捉樣本數(shù)據(jù)的分布,判別模型是一個(gè)二分類器,估計(jì)一個(gè)樣本來(lái)自于訓(xùn)練數(shù)據(jù)(而非生成數(shù)據(jù))的概率。G和D一般都是非線性映射函數(shù),例如多層感知機(jī)、卷積神經(jīng)網(wǎng)絡(luò)等。
二、生成對(duì)抗網(wǎng)絡(luò)的原理
1、生成對(duì)抗過(guò)程
GANs的方法,就是讓兩個(gè)網(wǎng)絡(luò)相互競(jìng)爭(zhēng)“玩一個(gè)游戲”。
其中一個(gè)叫做生成器網(wǎng)絡(luò)( Generator Network),它不斷捕捉訓(xùn)練庫(kù)里真實(shí)圖片的概率分布,將輸入的隨機(jī)噪聲(Random Noise)轉(zhuǎn)變成新的樣本(也就是假數(shù)據(jù))。
另一個(gè)叫做判別器網(wǎng)絡(luò)(Discriminator Network),它可以同時(shí)觀察真實(shí)和假造的數(shù)據(jù),判斷這個(gè)數(shù)據(jù)到底是不是真的。
所以整個(gè)訓(xùn)練過(guò)程包含兩步,(在下圖里,判別器用 D 表示,生成器用 G 表示,真實(shí)數(shù)據(jù)庫(kù)樣本用 X 表示,噪聲用 Z 表示)。
第一步,只有判別器D參與。
我們把X樣本輸入可微函數(shù)D里運(yùn)行,D輸出0-1之間的某個(gè)值,數(shù)值越大意味著X樣本是真實(shí)的可能性越大。在這個(gè)過(guò)程中,判別器D盡可能使輸出的值靠近1,因?yàn)檫@一階段的X樣本就是真實(shí)的圖片。
第二步,判別器D和生成器G都參與。
我們首先將噪聲數(shù)據(jù)Z喂給生成器G,G從原有真實(shí)圖像庫(kù)里學(xué)習(xí)概率分布,從而產(chǎn)生假的圖像樣本。然后,我們把假的數(shù)據(jù)交給判別器D。這一次,D將盡可能輸入數(shù)值0,這代表著輸入數(shù)據(jù)Z是假的。
所以這個(gè)過(guò)程中,判別器D相當(dāng)于一個(gè)監(jiān)督情況下的二分類器,數(shù)據(jù)要么歸為1,要么歸為0。
與傳統(tǒng)神經(jīng)網(wǎng)絡(luò)訓(xùn)練不一樣的且有趣的地方,就是我們訓(xùn)練生成器的方法不同。生成器一心想要“騙過(guò)”判別器。使用博弈理論分析技術(shù),我們可以證明這里面存在一種均衡。
2、數(shù)學(xué)原理
在訓(xùn)練過(guò)程中,生成網(wǎng)絡(luò)G的目標(biāo)就是盡量生成真實(shí)的圖片去欺騙判別網(wǎng)絡(luò)D。而D的目標(biāo)就是盡量把G生成的圖片和真實(shí)的圖片分別開來(lái)。這樣,G和D構(gòu)成了一個(gè)動(dòng)態(tài)的“博弈過(guò)程”。
最后博弈的結(jié)果是什么?在最理想的狀態(tài)下,G可以生成足以“以假亂真”的圖片G(z)。對(duì)于D來(lái)說(shuō),它難以判定G生成的圖片究竟是不是真實(shí)的,因此D(G(z)) = 0.5。
這樣我們的目的就達(dá)成了:我們得到了一個(gè)生成式的模型G,它可以用來(lái)生成圖片。
以上只是大致說(shuō)了一下GAN的核心原理,如何用數(shù)學(xué)語(yǔ)言描述呢?這里直接摘錄論文里的公式:
簡(jiǎn)單分析一下這個(gè)公式:
? 整個(gè)式子由兩項(xiàng)構(gòu)成。x表示真實(shí)圖片,z表示輸入G網(wǎng)絡(luò)的噪聲,而G(z)表示G網(wǎng)絡(luò)生成的圖片。
? D(x)表示D網(wǎng)絡(luò)判斷真實(shí)圖片是否真實(shí)的概率(因?yàn)閤就是真實(shí)的,所以對(duì)于D來(lái)說(shuō),這個(gè)值越接近1越好)。而D(G(z))是D網(wǎng)絡(luò)判斷G生成的圖片的是否真實(shí)的概率。
? G的目的:上面提到過(guò),D(G(z))是D網(wǎng)絡(luò)判斷G生成的圖片是否真實(shí)的概率,G應(yīng)該希望自己生成的圖片“越接近真實(shí)越好”。也就是說(shuō),G希望D(G(z))盡可能得大,這時(shí)V(D, G)會(huì)變小。因此我們看到式子的最前面的記號(hào)是min_G。
? D的目的:D的能力越強(qiáng),D(x)應(yīng)該越大,D(G(x))應(yīng)該越小。這時(shí)V(D,G)會(huì)變大。因此式子對(duì)于D來(lái)說(shuō)是求最大(max_D)
三、GAN的優(yōu)勢(shì)與缺陷
1、優(yōu)勢(shì)
? 根據(jù)實(shí)際的結(jié)果,它們看上去可以比其它模型產(chǎn)生了更好的樣本(圖像更銳利、清晰)。
? 生成對(duì)抗式網(wǎng)絡(luò)框架能訓(xùn)練任何一種生成器網(wǎng)絡(luò)(理論上-實(shí)踐中,用 REINFORCE 來(lái)訓(xùn)練帶有離散輸出的生成網(wǎng)絡(luò)非常困難)。大部分其他的框架需要該生成器網(wǎng)絡(luò)有一些特定的函數(shù)形式,比如輸出層是高斯的。重要的是所有其他的框架需要生成器網(wǎng)絡(luò)遍布非零質(zhì)量(non-zero mass)。生成對(duì)抗式網(wǎng)絡(luò)能學(xué)習(xí)可以僅在與數(shù)據(jù)接近的細(xì)流形(thin manifold)上生成點(diǎn)。
? 不需要設(shè)計(jì)遵循任何種類的因式分解的模型,任何生成器網(wǎng)絡(luò)和任何鑒別器都會(huì)有用。
? 無(wú)需利用馬爾科夫鏈反復(fù)采樣,無(wú)需在學(xué)習(xí)過(guò)程中進(jìn)行推斷(Inference),回避了近似計(jì)算棘手的概率的難題。
2、存在的主要問題:
? 解決不收斂(non-convergence)的問題。
目前面臨的基本問題是:所有的理論都認(rèn)為 GAN 應(yīng)該在納什均衡(Nash equilibrium)上有卓越的表現(xiàn),但梯度下降只有在凸函數(shù)的情況下才能保證實(shí)現(xiàn)納什均衡。當(dāng)博弈雙方都由神經(jīng)網(wǎng)絡(luò)表示時(shí),在沒有實(shí)際達(dá)到均衡的情況下,讓它們永遠(yuǎn)保持對(duì)自己策略的調(diào)整是可能的【OpenAI Ian Goodfellow的Quora】。
? 難以訓(xùn)練:崩潰問題(collapse problem)
GAN模型被定義為極小極大問題,沒有損失函數(shù),在訓(xùn)練過(guò)程中很難區(qū)分是否正在取得進(jìn)展。GAN的學(xué)習(xí)過(guò)程可能發(fā)生崩潰問題(collapse problem),生成器開始退化,總是生成同樣的樣本點(diǎn),無(wú)法繼續(xù)學(xué)習(xí)。當(dāng)生成模型崩潰時(shí),判別模型也會(huì)對(duì)相似的樣本點(diǎn)指向相似的方向,訓(xùn)練無(wú)法繼續(xù)。
? 無(wú)需預(yù)先建模,模型過(guò)于自由不可控。
與其他生成式模型相比,GAN這種競(jìng)爭(zhēng)的方式不再要求一個(gè)假設(shè)的數(shù)據(jù)分布,即不需要formulate p(x),而是使用一種分布直接進(jìn)行采樣sampling,從而真正達(dá)到理論上可以完全逼近真實(shí)數(shù)據(jù),這也是GAN最大的優(yōu)勢(shì)。然而,這種不需要預(yù)先建模的方法缺點(diǎn)是太過(guò)自由了,對(duì)于較大的圖片,較多的 pixel的情形,基于簡(jiǎn)單 GAN 的方式就不太可控了。在GAN[Goodfellow Ian, Pouget-Abadie J] 中,每次學(xué)習(xí)參數(shù)的更新過(guò)程,被設(shè)為D更新k回,G才更新1回,也是出于類似的考慮。
四、DCGANs:深度卷積生成對(duì)抗網(wǎng)絡(luò)
DCGANs的基本架構(gòu)就是使用幾層“反卷積”(Deconvolution)網(wǎng)絡(luò)。“反卷積”類似于一種反向卷積,這跟用反向傳播算法訓(xùn)練監(jiān)督的卷積神經(jīng)網(wǎng)絡(luò)(CNN)是類似的操作。
CNN是將圖像的尺寸壓縮,變得越來(lái)越小,而反卷積是將初始輸入的小數(shù)據(jù)(噪聲)變得越來(lái)越大(但反卷積并不是CNN的逆向操作,這個(gè)下面會(huì)有詳解)。
如果你要把卷積核移動(dòng)不止一個(gè)位置, 使用的卷積滑動(dòng)步長(zhǎng)更大,那么在反卷積的每一層,你所得到的圖像尺寸就會(huì)越大。
這個(gè)論文里另一個(gè)重要思想,就是在大部分網(wǎng)絡(luò)層中使用了“批量規(guī)范化”(batch normalization),這讓學(xué)習(xí)過(guò)程的速度更快且更穩(wěn)定。另一個(gè)有趣的思想就是,如何處理生成器里的“池化層”(Pooling Layers),傳統(tǒng)CNN使用的池化層,往往取區(qū)域平均或最大來(lái)壓縮表征數(shù)據(jù)的尺寸。
在反卷積過(guò)程中,從代碼到最終生成圖片,表征數(shù)據(jù)變得越來(lái)越大,我們需要某個(gè)東西來(lái)逐漸擴(kuò)大表征的尺寸。但最大值池化(max-pooling)過(guò)程并不可逆,所以DCGANs那篇論文里,并沒有采用池化的逆向操作,而只是讓“反卷積”的滑動(dòng)步長(zhǎng)設(shè)定為2或更大值,這一方法確實(shí)會(huì)讓表征尺寸按我們的需求增大。
DCGANs非常擅長(zhǎng)生成特定Domain里的小圖片,這里是一些生成的“臥室”圖片樣本。這些圖片分辨率不是很高,但是你可以看到里面包含了門、窗戶、棉被、枕頭、床頭板、燈具等臥室常見物品。
五、生成對(duì)抗網(wǎng)絡(luò)應(yīng)用
1、GANs的應(yīng)用:“文本轉(zhuǎn)圖像”(Text to Image)
我們可以用GANs做很多應(yīng)用,其中一種就是“文本轉(zhuǎn)圖像”(Text to Image)。在Scott Reed等人的一篇論文里(Generative Adversarial Text to Image Synthesis,鏈接 https://arxiv.org/abs/1605.05396),GANs根據(jù)輸入的信息產(chǎn)生了相關(guān)圖像,。
也就是說(shuō),生成器里輸入的不僅是隨機(jī)噪聲,還有一些特定的語(yǔ)句信息。所以判別器不僅要區(qū)分樣本是否是真實(shí)的,還要判定其是否與輸入的語(yǔ)句信息相符。
這里是他們的實(shí)驗(yàn)結(jié)果,左上角的圖里有一些鳥,鳥的胸脯和鳥冠是是粉色,主羽和次羽是黑色,與所給語(yǔ)句描述的信息相符。
但是我們也看到,仍然存在“模型崩潰”問題,在右下角的黃白花里,確實(shí)產(chǎn)生了白色花瓣和黃色花蕊的花朵,但它們多少看起來(lái)是在同一個(gè)方向上映射出來(lái)的同一朵花,它們的花瓣數(shù)和尺寸幾乎相同。
所以,模型在輸出的多樣性方面還有些問題,這需要解決。但可喜的地方在于,輸入的語(yǔ)句信息都比較好的映射到產(chǎn)生的圖像樣本中。
2、有趣的GANs 圖像生成應(yīng)用
在Indico和Facebook發(fā)布了他們自己的DCGAN代碼之后,很多人開發(fā)出他們自己的、有趣的GANs應(yīng)用。有的生成新的花朵圖像,還有新動(dòng)漫角色。我個(gè)人最喜歡的,是一個(gè)能生成新品種精靈寶可夢(mèng)的應(yīng)用。
在一個(gè) Youtube 視頻,你會(huì)看到學(xué)習(xí)過(guò)程:生成器被迫去學(xué)習(xí)怎么騙過(guò)判別器,圖像逐漸變得更真實(shí)。有些生成的寶可夢(mèng),雖然它們是全新的品種,看上去就像真的一樣。這些圖像的真實(shí)感并沒有一些專業(yè)學(xué)術(shù)論文里面的那么強(qiáng),但對(duì)于現(xiàn)在的生成模型來(lái)說(shuō),不經(jīng)過(guò)任何額外處理就能得到這樣的結(jié)果,已經(jīng)非常不錯(cuò)了。
3、超分辨率
一篇最近發(fā)表的論文,描述怎么利用GANs進(jìn)行超分辨率重建(Super-Resolution)。我不確定這能否在本視頻中體現(xiàn)出來(lái),因?yàn)橐曨l清晰度的限制。基本思想是,你可以在有條件的GANs里,輸入低分辨率圖像,然后輸出高分版本。使用生成模型的原因在于,這是一個(gè)約束不足(underconstrained)的問題:對(duì)于任何一個(gè)低分辨率圖像,有無(wú)數(shù)種可能的高分辨率版本。相比其他生成模型,GANs特別適用超分辨率應(yīng)用。因?yàn)镚ANs的專長(zhǎng)就是創(chuàng)建極有真實(shí)感的樣本。它們并不特別擅長(zhǎng)做概率函數(shù)密度的估測(cè),但在超分辨率應(yīng)用中,我們最終關(guān)心的是輸出高分圖像,而不是概率分布。
(從左到右分別為:圖1、2、3、4)
上面展示的四幅圖像中,最左邊的是原始高分圖像(圖1),剩下的其余三張圖片都是通過(guò)對(duì)圖片的降采樣(Down Sample)生成的。我們把降采樣得到的圖片用不同的方法進(jìn)行放大,以期得到跟原始圖像同樣的品質(zhì)。
這些方法有很多種,比如我們用雙三次插值(Bicubic Interpolation)方式,生成的圖像(圖2)看起來(lái)很模糊,且對(duì)比度很低。另一個(gè)深度學(xué)習(xí)方法SRResNet(圖3)的效果更好,圖片已經(jīng)干凈了很多。但若采用GANs重建的圖片(圖4),有著比其它兩種方式更低的信噪比。雖然我們直觀上覺得圖3看起來(lái)更清晰,事實(shí)上它的信噪比更高一些。GANs在量化矩陣(Quantitative Matrix)和人眼清晰度感知兩方面,都有很好的表現(xiàn)。
六、TensorFlow源碼(生成手寫字體)
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import numpy as np from skimage.io import imsave import os import shutilimg_height = 28 img_width = 28 img_size = img_height * img_widthto_train = True to_restore = False output_path = "output"# 總迭代次數(shù)500 max_epoch = 500h1_size = 150 h2_size = 300 z_size = 100 batch_size = 256# generate (model 1) def build_generator(z_prior):w1 = tf.Variable(tf.truncated_normal([z_size, h1_size], stddev=0.1), name="g_w1", dtype=tf.float32)b1 = tf.Variable(tf.zeros([h1_size]), name="g_b1", dtype=tf.float32)h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)w2 = tf.Variable(tf.truncated_normal([h1_size, h2_size], stddev=0.1), name="g_w2", dtype=tf.float32)b2 = tf.Variable(tf.zeros([h2_size]), name="g_b2", dtype=tf.float32)h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)w3 = tf.Variable(tf.truncated_normal([h2_size, img_size], stddev=0.1), name="g_w3", dtype=tf.float32)b3 = tf.Variable(tf.zeros([img_size]), name="g_b3", dtype=tf.float32)h3 = tf.matmul(h2, w3) + b3x_generate = tf.nn.tanh(h3)g_params = [w1, b1, w2, b2, w3, b3]return x_generate, g_params# discriminator (model 2) def build_discriminator(x_data, x_generated, keep_prob):# tf.concatx_in = tf.concat([x_data, x_generated],0)w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)h3 = tf.matmul(h2, w3) + b3y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))d_params = [w1, b1, w2, b2, w3, b3]return y_data, y_generated, d_params# def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5img_h, img_w = batch_res.shape[1], batch_res.shape[2]grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)for i, res in enumerate(batch_res):if i >= grid_size[0] * grid_size[1]:breakimg = (res) * 255img = img.astype(np.uint8)row = (i // grid_size[0]) * (img_h + grid_pad)col = (i % grid_size[1]) * (img_w + grid_pad)img_grid[row:row + img_h, col:col + img_w] = imgimsave(fname, img_grid)def train():# load data(mnist手寫數(shù)據(jù)集)mnist = input_data.read_data_sets('MNIST_data', one_hot=True)x_data = tf.placeholder(tf.float32, [batch_size, img_size], name="x_data")z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")keep_prob = tf.placeholder(tf.float32, name="keep_prob")global_step = tf.Variable(0, name="global_step", trainable=False)# 創(chuàng)建生成模型x_generated, g_params = build_generator(z_prior)# 創(chuàng)建判別模型y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob)# 損失函數(shù)的設(shè)置d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))g_loss = - tf.log(y_generated)optimizer = tf.train.AdamOptimizer(0.0001)# 兩個(gè)模型的優(yōu)化函數(shù)d_trainer = optimizer.minimize(d_loss, var_list=d_params)g_trainer = optimizer.minimize(g_loss, var_list=g_params)init = tf.initialize_all_variables()saver = tf.train.Saver()# 啟動(dòng)默認(rèn)圖sess = tf.Session()# 初始化sess.run(init)if to_restore:chkpt_fname = tf.train.latest_checkpoint(output_path)saver.restore(sess, chkpt_fname)else:if os.path.exists(output_path):shutil.rmtree(output_path)os.mkdir(output_path)z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)steps = 60000 / batch_sizefor i in range(sess.run(global_step), max_epoch):for j in np.arange(steps): # for j in range(steps):print("epoch:%s, iter:%s" % (i, j))# 每一步迭代,我們都會(huì)加載256個(gè)訓(xùn)練樣本,然后執(zhí)行一次train_stepx_value, _ = mnist.train.next_batch(batch_size)x_value = 2 * x_value.astype(np.float32) - 1z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)# 執(zhí)行生成sess.run(d_trainer,feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})# 執(zhí)行判別if j % 1 == 0:sess.run(g_trainer,feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})show_result(x_gen_val, "output/sample{0}.jpg".format(i))z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})show_result(x_gen_val, "output/random_sample{0}.jpg".format(i))sess.run(tf.assign(global_step, i + 1))saver.save(sess, os.path.join(output_path, "model"), global_step=global_step)def test():z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")x_generated, _ = build_generator(z_prior)chkpt_fname = tf.train.latest_checkpoint(output_path)init = tf.initialize_all_variables()sess = tf.Session()saver = tf.train.Saver()sess.run(init)saver.restore(sess, chkpt_fname)z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})show_result(x_gen_val, "output/test_result.jpg")if __name__ == '__main__':if to_train:train()else:test()參考文獻(xiàn)
http://blog.csdn.net/solomon1558/article/details/52549409
http://www.leiphone.com/news/201612/eAOGpvFl60EgFSwS.html
http://www.itwendao.com/article/detail/403491.html
總結(jié)
以上是生活随笔為你收集整理的生成对抗网络入门详解及TensorFlow源码实现--深度学习笔记的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: tcpdump命令使用总结
- 下一篇: Ubuntu服务器上搭建solo个人博客