生成对抗网络——GAN(一)
Generative adversarial network
據(jù)有關(guān)媒體統(tǒng)計(jì):CVPR2018的論文里,有三分之一的論文與GAN有關(guān)
由此可見,GAN在視覺(jué)領(lǐng)域的未來(lái)多年內(nèi),將是一片沃土(CVer們是時(shí)候入門GAN了)。而發(fā)現(xiàn)這片礦源的就是GAN之父,Goodfellow大神。
文末有基于keras的GAN代碼,有助于理解GAN的原理
生成對(duì)抗網(wǎng)絡(luò)GAN,是當(dāng)今的一大熱門研究方向。在2014年,被Goodfellow大神提出來(lái),當(dāng)時(shí)的G神還只是蒙特利爾大學(xué)的博士生而已。
GAN之父的主頁(yè):
http://www.iangoodfellow.com/
GAN的論文首次出現(xiàn)在NIPS2014上,論文地址如下:
https://arxiv.org/pdf/1406.2661.pdf
入坑GAN,首先需要理由,GAN能做什么,為什么要學(xué)GAN。
GAN的初衷就是生成不存在于真實(shí)世界的數(shù)據(jù),類似于使得 AI具有創(chuàng)造力或者想象力。應(yīng)用場(chǎng)景如下:
以上的場(chǎng)景都可以找到相應(yīng)的paper。而且GAN的用處也遠(yuǎn)不止此,期待我們繼續(xù)挖掘,是發(fā)論文的好方向哦
GAN的原理介紹
這里介紹的是原生的GAN算法,雖然有一些不足,但提供了一種生成對(duì)抗性的新思路。放心,我這篇博文不會(huì)堆一大堆公式,只會(huì)提供一種理解思路。
理解GAN的兩大護(hù)法G和D
G是generator,生成器: 負(fù)責(zé)憑空捏造數(shù)據(jù)出來(lái)
D是discriminator,判別器: 負(fù)責(zé)判斷數(shù)據(jù)是不是真數(shù)據(jù)
這樣可以簡(jiǎn)單的看作是兩個(gè)網(wǎng)絡(luò)的博弈過(guò)程。在最原始的GAN論文里面,G和D都是兩個(gè)多層感知機(jī)網(wǎng)絡(luò)。首先,注意一點(diǎn),GAN操作的數(shù)據(jù)不一定非得是圖像數(shù)據(jù),不過(guò)為了更方便解釋,我在這里用圖像數(shù)據(jù)為例解釋以下GAN:
稍微解釋以下上圖,z是隨機(jī)噪聲(就是隨機(jī)生成的一些數(shù),也就是GAN生成圖像的源頭)。D通過(guò)真圖和假圖的數(shù)據(jù)(相當(dāng)于天然label),進(jìn)行一個(gè)二分類神經(jīng)網(wǎng)絡(luò)訓(xùn)練(想各位必再熟悉不過(guò)了)。G根據(jù)一串隨機(jī)數(shù)就可以捏造一個(gè)“假圖像”出來(lái),用這些假圖去欺騙D,D負(fù)責(zé)辨別這是真圖還是假圖,會(huì)給出一個(gè)score。比如,G生成了一張圖,在D這里得分很高,那證明G是很成功的;如果D能有效區(qū)分真假圖,則G的效果還不太好,需要調(diào)整參數(shù)。GAN就是這么一個(gè)博弈的過(guò)程。
那么,GAN是怎么訓(xùn)練呢?
根據(jù)GAN的訓(xùn)練算法,我畫一張圖:
GAN的訓(xùn)練在同一輪梯度反傳的過(guò)程中可以細(xì)分為2步,先訓(xùn)練D在訓(xùn)練G;注意不是等所有的D訓(xùn)練好以后,才開始訓(xùn)練G,因?yàn)镈的訓(xùn)練也需要上一輪梯度反傳中G的輸出值作為輸入。
當(dāng)訓(xùn)練D的時(shí)候,上一輪G產(chǎn)生的圖片,和真實(shí)圖片,直接拼接在一起,作為x。然后根據(jù),按順序擺放0和1,假圖對(duì)應(yīng)0,真圖對(duì)應(yīng)1。然后就可以通過(guò),x輸入生成一個(gè)score(從0到1之間的數(shù)),通過(guò)score和y組成的損失函數(shù),就可以進(jìn)行梯度反傳了。(我在圖片上舉的例子是batch = 1,len(y)=2*batch,訓(xùn)練時(shí)通??梢匀≥^大的batch)
當(dāng)訓(xùn)練G的時(shí)候, 需要把G和D當(dāng)作一個(gè)整體,我在這里取名叫做’D_on_G’。這個(gè)整體(下面簡(jiǎn)稱DG系統(tǒng))的輸出仍然是score。輸入一組隨機(jī)向量,就可以在G生成一張圖,通過(guò)D對(duì)生成的這張圖進(jìn)行打分,這就是DG系統(tǒng)的前向過(guò)程。score=1就是DG系統(tǒng)需要優(yōu)化的目標(biāo),score和y=1之間的差異可以組成損失函數(shù),然后可以反向傳播梯度。注意,這里的D的參數(shù)是不可訓(xùn)練的。這樣就能保證G的訓(xùn)練是符合D的打分標(biāo)準(zhǔn)的。這就好比:如果你參加考試,你別指望能改變老師的評(píng)分標(biāo)準(zhǔn)
需要注意的是,整個(gè)GAN的整個(gè)過(guò)程都是無(wú)監(jiān)督的(后面會(huì)有監(jiān)督性GAN比如cGAN),怎么理解這里的無(wú)監(jiān)督呢?
這里,給的真圖是沒(méi)有經(jīng)過(guò)人工標(biāo)注的,你只知道這是真實(shí)的圖片,比如全是人臉,而系統(tǒng)里的D并不知道來(lái)的圖片是什么玩意兒,它只需要分辨真假。G也不知道自己生成的是什么玩意兒,反正就是學(xué)真圖片的樣子騙D。
正由于GAN的無(wú)監(jiān)督,在生成過(guò)程中,G就會(huì)按照自己的意思天馬行空生成一些“詭異”的圖片,可怕的是D還能給一個(gè)很高的分?jǐn)?shù)。比如,生成人臉極度扭曲的圖片。這就是無(wú)監(jiān)督目的性不強(qiáng)所導(dǎo)致的,所以在同年的NIPS大會(huì)上,有一篇論文conditional GAN就加入了監(jiān)督性進(jìn)去,將可控性增強(qiáng),表現(xiàn)效果也好很多。
from __future__ import print_function, divisionfrom keras.datasets import mnist from keras.layers import Input, Dense, Reshape, Flatten, Dropout from keras.layers import BatchNormalization, Activation, ZeroPadding2D from keras.layers.advanced_activations import LeakyReLU from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model from keras.optimizers import Adamimport matplotlib.pyplot as pltimport sysimport numpy as npclass GAN():def __init__(self):self.img_rows = 28self.img_cols = 28self.channels = 1self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 100optimizer = Adam(0.0002, 0.5)# Build and compile the discriminatorself.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])# Build the generatorself.generator = self.build_generator()# The generator takes noise as input and generates imgsz = Input(shape=(self.latent_dim,))img = self.generator(z)# For the combined model we will only train the generatorself.discriminator.trainable = False# The discriminator takes generated images as input and determines validityvalidity = self.discriminator(img)# The combined model (stacked generator and discriminator)# Trains the generator to fool the discriminatorself.combined = Model(z, validity)self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)def build_generator(self):model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))model.summary()noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)def build_discriminator(self):model = Sequential()model.add(Flatten(input_shape=self.img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))model.add(Dense(1, activation='sigmoid'))model.summary()img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)def train(self, epochs, batch_size=128, sample_interval=50):# Load the dataset(X_train, _), (_, _) = mnist.load_data()# Rescale -1 to 1X_train = X_train / 127.5 - 1.X_train = np.expand_dims(X_train, axis=3)# Adversarial ground truthsvalid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# ---------------------# Train Discriminator# ---------------------# Select a random batch of imagesidx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]noise = np.random.normal(0, 1, (batch_size, self.latent_dim))# Generate a batch of new imagesgen_imgs = self.generator.predict(noise)# Train the discriminatord_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------# Train Generator# ---------------------noise = np.random.normal(0, 1, (batch_size, self.latent_dim))# Train the generator (to have the discriminator label samples as valid)g_loss = self.combined.train_on_batch(noise, valid)# Plot the progressprint ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))# If at save interval => save generated image samplesif epoch % sample_interval == 0:self.sample_images(epoch)def sample_images(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)# Rescale images 0 - 1gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%d.png" % epoch)plt.close()if __name__ == '__main__':gan = GAN()gan.train(epochs=30000, batch_size=32, sample_interval=200)
總結(jié)
以上是生活随笔為你收集整理的生成对抗网络——GAN(一)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: iOS动画-CAAnimation使用详
- 下一篇: 数据采集 复习题