利用GAN原始框架生成手写数字
生活随笔
收集整理的這篇文章主要介紹了
利用GAN原始框架生成手写数字
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
這一篇GAN文章只是讓產生的結果盡量真實,還不能分類。
本次手寫數字GAN的思想:
對于辨別器,利用真實的手寫數字(真樣本,對應的標簽為真標簽)和隨機噪聲經過生成器產生的樣本(假樣本,對應的標簽為假標簽)送入辨別器,分別得到兩個損失值,最小化這兩個損失值,這樣的話就能保證辨別器能分清楚真假。
而對于生成器,用產生的隨機噪聲送入生成器,產生樣本(假樣本,對應的標簽為真標簽),得到損失值,最小化損失值,注意標簽要改為真標簽,因為這樣才能以假亂真,正是辨別器和生成器博弈的過程,使得生成的數據能夠以假亂真,為什么博弈到平衡的狀態隨機噪聲能夠以假亂真呢?看公式,論證了全局最優值,就是噪聲和真實數據相等的時候。
詳細推導過程
下面看代碼:test做測試用的,不用管。
路徑:
util.py代碼:
import tensorflow as tf import numpy as np """ 從正太分布輸出隨機值 """ def xavier_init(size):in_dim=size[0]xavier_stddev=tf.sqrt(2./in_dim)return tf.random_normal(shape=size,stddev=xavier_stddev)#生成模型的輸入和參數初始化G_W1=tf.Variable(xavier_init(size=[100,128])) G_b1=tf.Variable(tf.zeros(shape=[128]))G_W2=tf.Variable(xavier_init(size=[128,784])) G_b2=tf.Variable(tf.zeros(shape=[784]))theta_G=[G_W1,G_W2,G_b1,G_b2]#判別模型的輸入和參數初始化D_W1=tf.Variable(xavier_init(size=[784,128])) D_b1=tf.Variable(tf.zeros(shape=[128]))D_W2=tf.Variable(xavier_init(size=[128,1])) D_b2=tf.Variable(tf.zeros(shape=[1]))theta_D=[D_W1,D_W2,D_b1,D_b2]""" 隨機噪聲產生 """ def sample_z(m,n):return np.random.uniform(-1.0,1.0,size=[m,n]) """ 生成模型:產生數據 """ def generator(z):G_h1=tf.nn.relu(tf.matmul(z,G_W1)+G_b1)G_log_prob=tf.matmul(G_h1, G_W2) + G_b2G_prob=tf.nn.sigmoid(G_log_prob)return G_prob""" 判別模型:真實值和概率值 """ def discriminator(x):D_h1=tf.nn.relu(tf.matmul(x,D_W1)+D_b1)D_logit=tf.matmul(D_h1, D_W2) + D_b2D_prob=tf.nn.sigmoid(D_logit)return D_prob,D_logitmain.py代碼:
import tensorflow as tf import numpy as np from GAN.TWO import util import os from tensorflow.examples.tutorials.mnist import input_data import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec #讀入數據 mnist=input_data.read_data_sets('./data',one_hot=True) # print(mnist)Z=tf.placeholder(tf.float32,shape=[None,100])X=tf.placeholder(tf.float32,shape=[None,784]) #喂入數據 G_sample=util.generator(Z) D_real,D_logit_real=util.discriminator(X) D_fake,D_logit_fake=util.discriminator(G_sample) #計算loss D_real_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real,labels=tf.ones_like(D_logit_real))) D_fake_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake,labels=tf.zeros_like(D_logit_fake))) D_loss=D_fake_loss+D_real_lossG_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake,labels=tf.ones_like(D_logit_fake)))D_optimizer=tf.train.AdamOptimizer().minimize(D_loss,var_list=util.theta_D) G_optimizer=tf.train.AdamOptimizer().minimize(G_loss,var_list=util.theta_G)if not os.path.exists('out/'):os.makedirs('out/') """ 畫圖 """ def plot(samples):gs=gridspec.GridSpec(4,4)gs.update(wspace=0.05,hspace=.05)for i,sample in enumerate(samples):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28,28),cmap='Greys_r') print("=====================開始訓練============================") with tf.Session() as sess:sess.run(tf.global_variables_initializer())for it in range(100000):X_mb,_=mnist.train.next_batch(batch_size=128)# print(X_mb)_,D_loss_curr=sess.run([D_optimizer,D_loss],feed_dict={X:X_mb,Z:util.sample_z(128,100)})_, G_loss_curr = sess.run([G_optimizer, G_loss],feed_dict={Z: util.sample_z(128, 100)})if it%1000==0:print('====================打印出生成的數據============================')samples=sess.run(G_sample,feed_dict={Z: util.sample_z(16, 100)})plot(samples)plt.show()if it%1000==0:print('iter={}'.format(it))print('D_loss={}'.format(D_loss_curr))print('G_loss={}'.format(G_loss_curr))打印結果:
迭代0次。
迭代50000次。
總結
以上是生活随笔為你收集整理的利用GAN原始框架生成手写数字的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: kettle创建mysql资源库
- 下一篇: 微信开发简单实例