GAN 对抗生成网络代码实现
生活随笔
收集整理的這篇文章主要介紹了
GAN 对抗生成网络代码实现
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
作報告寫了ppt,這里po上?
更完整的介紹關注專欄生成對抗網絡Generative Adversarial Network
本篇的同名博客[生成對抗網絡GAN入門指南](3)GAN的工程實踐及基礎代碼
In?[1]:
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import numpy as np import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import osIn?[2]:
#該函數將給出權重初始化的方法 def variable_init(size):in_dim = size[0]#計算隨機生成變量所服從的正態分布標準差w_stddev = 1. / tf.sqrt(in_dim / 2.)return tf.random_normal(shape=size, stddev=w_stddev)In?[3]:
#定義輸入矩陣的占位符,輸入層單元為784,None代表批量大小的占位,X代表輸入的真實圖片。占位符的數值類型為32位浮點型 X = tf.placeholder(tf.float32, shape=[None, 784])#定義判別器的權重矩陣和偏置項向量,由此可知判別網絡為三層全連接網絡 D_W1 = tf.Variable(variable_init([784, 128])) D_b1 = tf.Variable(tf.zeros(shape=[128]))D_W2 = tf.Variable(variable_init([128, 1])) D_b2 = tf.Variable(tf.zeros(shape=[1]))theta_D = [D_W1, D_W2, D_b1, D_b2]#定義生成器的輸入噪聲為100維度的向量組,None根據批量大小確定 Z = tf.placeholder(tf.float32, shape=[None, 100])#定義生成器的權重與偏置項。輸入層為100個神經元且接受隨機噪聲, #輸出層為784個神經元,并輸出手寫字體圖片。生成網絡根據原論文為三層全連接網絡 G_W1 = tf.Variable(variable_init([100, 128])) G_b1 = tf.Variable(tf.zeros(shape=[128]))G_W2 = tf.Variable(variable_init([128, 784])) G_b2 = tf.Variable(tf.zeros(shape=[784]))theta_G = [G_W1, G_W2, G_b1, G_b2]In?[4]:
#定義一個可以生成m*n階隨機矩陣的函數,該矩陣的元素服從均勻分布,隨機生成的z就為生成器的輸入 def sample_Z(m, n):return np.random.uniform(-1., 1., size=[m, n])In?[5]:
#定義生成器 def generator(z):#第一層先計算 y=z*G_W1+G-b1,然后投入激活函數計算G_h1=ReLU(y),G_h1 為第二次層神經網絡的輸出激活值G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)#以下兩個語句計算第二層傳播到第三層的激活結果,第三層的激活結果是含有784個元素的向量,該向量轉化28×28就可以表示圖像G_log_prob = tf.matmul(G_h1, G_W2) + G_b2G_prob = tf.nn.sigmoid(G_log_prob)return G_probIn?[6]:
#定義判別器 def discriminator(x):#計算D_h1=ReLU(x*D_W1+D_b1),該層的輸入為含784個元素的向量D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)#計算第三層的輸出結果。因為使用的是Sigmoid函數,則該輸出結果是一個取值為[0,1]間的標量(見上述權重定義)#即判別輸入的圖像到底是真(=1)還是假(=0)D_logit = tf.matmul(D_h1, D_W2) + D_b2D_prob = tf.nn.sigmoid(D_logit)#返回判別為真的概率和第三層的輸入值,輸出D_logit是為了將其輸入tf.nn.sigmoid_cross_entropy_with_logits()以構建損失函數return D_prob, D_logitIn?[7]:
#該函數用于輸出生成圖片 def plot(samples):fig = plt.figure(figsize=(4, 4))gs = gridspec.GridSpec(4, 4)gs.update(wspace=0.05, hspace=0.05)for i, sample in enumerate(samples):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28, 28), cmap='Greys_r')return fig交叉熵損失函數
函數的輸入是和,就是神經網絡模型中的矩陣,且不需要經過激活函數。而的shape和相同,即正確的標注值。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
那么該函數的表達式為
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
In?[8]:
#輸入隨機噪聲z而輸出生成樣本 G_sample = generator(Z)#分別輸入真實圖片和生成的圖片,并投入判別器以判斷真偽 D_real, D_logit_real = discriminator(X) D_fake, D_logit_fake = discriminator(G_sample)#以下為原論文的判別器損失和生成器損失,但本實現并沒有使用該損失函數 # D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake)) # G_loss = -tf.reduce_mean(tf.log(D_fake))# 我們使用交叉熵作為判別器和生成器的損失函數,因為sigmoid_cross_entropy_with_logits內部會對預測輸入執行Sigmoid函數, #所以我們取判別器最后一層未投入激活函數的值,即D_h1*D_W2+D_b2。 #tf.ones_like(D_logit_real)創建維度和D_logit_real相等的全是1的標注,真實圖片。 D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real))) D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))#損失函數為兩部分,即E[log(D(x))]+E[log(1-D(G(z)))],將真的判別為假和將假的判別為真 D_loss = D_loss_real + D_loss_fake#同樣使用交叉熵構建生成器損失函數 G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))#定義判別器和生成器的優化方法為Adam算法,關鍵字var_list表明最小化損失函數所更新的權重矩陣 D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)In?[9]:
#選擇訓練的批量大小和隨機生成噪聲的維度 mb_size = 128 Z_dim = 100#讀取數據集MNIST,并放在當前目錄data文件夾下MNIST文件夾中,如果該地址沒有數據,則下載數據至該文件夾 mnist = input_data.read_data_sets("./data/MNIST/", one_hot=True) Extracting ./data/MNIST/train-images-idx3-ubyte.gz Extracting ./data/MNIST/train-labels-idx1-ubyte.gz Extracting ./data/MNIST/t10k-images-idx3-ubyte.gz Extracting ./data/MNIST/t10k-labels-idx1-ubyte.gzIn?[10]:
#打開一個會話運行計算圖 sess = tf.Session()#初始化所有定義的變量 sess.run(tf.global_variables_initializer())#如果當前目錄下不存在out文件夾,則創建該文件夾 if not os.path.exists('out/'):os.makedirs('out/')#初始化,并開始迭代訓練,100W次 i = 0 for it in range(20000):#每2000次輸出一張生成器生成的圖片if it % 2000 == 0:samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})fig = plot(samples)plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')i += 1plt.close(fig)#next_batch抽取下一個批量的圖片,該方法返回一個矩陣,即shape=[mb_size,784],每一行是一張圖片,共批量大小行X_mb, _ = mnist.train.next_batch(mb_size)#投入數據并根據優化方法迭代一次,計算損失后返回損失值_, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})#每迭代2000次輸出迭代數、生成器損失和判別器損失if it % 2000 == 0:print('Iter: {}'.format(it))print('D loss: {:.4}'. format(D_loss_curr))print('G_loss: {:.4}'.format(G_loss_curr))print() Iter: 0 D loss: 1.671 G_loss: 1.718Iter: 2000 D loss: 0.05008 G_loss: 4.74Iter: 4000 D loss: 0.3667 G_loss: 4.85Iter: 6000 D loss: 0.3974 G_loss: 4.059Iter: 8000 D loss: 0.7007 G_loss: 2.628Iter: 10000 D loss: 0.4421 G_loss: 3.05Iter: 12000 D loss: 0.7872 G_loss: 2.562Iter: 14000 D loss: 0.7155 G_loss: 2.877Iter: 16000 D loss: 0.9827 G_loss: 2.042Iter: 18000 D loss: 0.7171 G_loss: 1.966?
總結
以上是生活随笔為你收集整理的GAN 对抗生成网络代码实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: JS原型概念讲解
- 下一篇: VMware 12 安装 OS X 10