深度学习之基于GAN实现手写数字生成
在弄畢設的時候,室友的畢設是基于DCGAN實現音樂的自動生成。那是第一次接觸對抗神經網絡,當時聽室友的描述就是兩個CNN,一個生成一個監測,在互相博弈。
最近我關注的一個大神在弄有關于GAN的東西,所以就跟著學了一下,蠻有意思的,和之前的深度學習略有不同。
1.導入庫
import tensorflow as tf import matplotlib.pyplot as plt import numpy as np import glob import sys,os,pathlib,imageio2.基本原理
生成式對抗網絡(GAN)是一種深度學習模型,是近年來復雜分布上無監督學習最具前景的方法之一。2014年由lanGoodfellow引入深度學習領域,被評價為“20年來深度學習領域最酷的想法”。
機器學習的模型大體上可分為兩類,生成模型和判別模型。判別模型需要輸入變量,通過某種模型來預測。生成模型是給定某種隱含信息,來隨機產生觀測數據。在之前的深度學習實驗中,都是使用判別模型,來實現對某種事務的判別,例如:貓狗大戰、鳥類識別、手寫數字識別等。而生成模型接觸的并不多。GAN是更好的生成模型。
GAN主要包括了兩個部分:生成器generator與判別器discriminator。生成器主要用來學習真實圖像分布從而讓自身生成的圖像更加真實,從而騙過判別器。而判別器則需要對接收的圖片進行真假判別。
在訓練過程中,生成器努力地令生成的圖像更加真實,而判別器則努力地去識別圖像的真假,這個過程相當于二人博弈,隨著時間的推移,生成器和判別器在不斷地進行對抗。最終兩個網絡達到了一個動態均衡:生成器生成的圖像接近于真是圖像分布,而判別器識別不出真假圖像,對于給定圖像的預測為真的概率基本接近0.5(相當于隨機猜測類別)。
利用GAN生成手寫數字識別的流程圖如下所示:
對于給定的真實圖片,判別器要為其打上標簽1;
對于給定的生成圖片,判別器要為其打上標簽0;
對于生成器傳給辨別器的生成圖片,生成器希望辨別器打上標簽1.
GAN步驟:
1.生成器(Generator)接收隨機數并返回生成圖像。
2.將生成的數字圖像與實際數據集中的數字圖像一起送到鑒別器(Discriminator)。
3.鑒別器(Discriminator)接收真實和假圖像并返回概率,0到1之間的數字,1表示真,0表示假。
3.數據準備
在這一階段我們導入真實的手寫數字,對其進行打亂、batch、歸一化等操作。
(train_images,train_labels) ,(_,_) = tf.keras.datasets.mnist.load_data() train_images = train_images.reshape(train_images.shape[0],28,28,1).astype('float32') train_images = (train_images - 127.5)/127.5#歸一化到[-1,1]之間 batch_size = 256 buffer_size = 60000 datasets = tf.data.Dataset.from_tensor_slices(train_images) datasets = datasets.shuffle(buffer_size).batch(batch_size)4.生成器與判別器的構建
def Generator_model():#最終生成28*28*1的圖片model = tf.keras.Sequential([tf.keras.layers.Dense(256,input_shape=(100,)),#傳入的數據為長度為100的隨機向量tf.keras.layers.BatchNormalization(),#歸一化tf.keras.layers.LeakyReLU(),#高級一點的Relu函數tf.keras.layers.Dense(512),tf.keras.layers.BatchNormalization(),tf.keras.layers.LeakyReLU(),tf.keras.layers.Dense(28*28*1,activation='tanh'),tf.keras.layers.BatchNormalization(),tf.keras.layers.Reshape((28,28,1))#最后調整為(28,28,1)形狀的數據,與手寫數字的shape一致,作為生成器生成的圖片])return modeldef Discriminator_model():#判斷圖片是真正的圖片還是生成的model = tf.keras.Sequential([tf.keras.layers.Flatten(),#傳入一張圖片,將其展開成一維數組tf.keras.layers.Dense(512),tf.keras.layers.BatchNormalization(),tf.keras.layers.LeakyReLU(),tf.keras.layers.Dense(256),tf.keras.layers.BatchNormalization(),tf.keras.layers.LeakyReLU(),tf.keras.layers.Dense(1,activation='sigmoid')])return model generator = Generator_model() discriminator = Discriminator_model()5.生成器與判別器的loss構建
判別器的loss值:判斷真實圖片為1的loss與判斷生成圖片為0的loss之和。因為判別器希望將真實圖片判別為1,將生成圖片判別為0.
生成器的loss值:判斷生成圖片為1的loss。因為生成器希望生成的圖片是真實圖片,即判別為1.
參數設置
epochs = 100 noise_dim = 100 num_exp_to_generate = 16 seed = tf.random.normal([num_exp_to_generate,noise_dim])#16個長度為100的向量6.批次訓練
對一個batch_size的數據進行訓練
def train_step(images):noise = tf.random.normal([batch_size,noise_dim])#生成一個batch_size*noise_dim的數據,相當于生成了batch_size個長度為100的隨機向量with tf.GradientTape() as gen_tape,tf.GradientTape() as dis_tape:#兩個Tape,一個代表生成器,一個代表判別器。real_out = discriminator(images,training = True)#利用判別器對真實的圖片進行訓練,得到一個modelgen_image = generator(noise,training = True)#利用生成器對噪聲數據生成圖片fake_out = discriminator(gen_image, training=True)#利用判別器對生成的圖片進行訓練gen_loss = Generator_loss(fake_out)#利用判別器對生成圖片的判斷計算生成器的loss值dis_loss = Discriminator_loss(real_out,fake_out)##利用判別器對生成圖片和真實圖片的判斷計算判別器的loss值gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables)#根據生成器的loss值和網絡模型計算梯度gradient_dis = dis_tape.gradient(dis_loss, discriminator.trainable_variables)#根據判別器的loss值和網絡模型計算梯度Generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))#根據梯度對生成器進行梯度更新Discriminator_opt.apply_gradients(zip(gradient_dis,discriminator.trainable_variables))#根據梯度對判別器進行梯度更新7.訓練&&可視化
def train(dataset,epochs):for epoch in range(epochs):#一共訓練epochs次for image_batch in dataset:#對dataset中的每一個batch進行訓練train_step(image_batch)print('.',end='')print()Generator_plot_image(generator,seed,epoch)#根據訓練好的生成器,對之前生成的seed進行處理,生成圖片 train(datasets,epochs) def Generator_plot_image(gen_model,test_noise,epoch):pre_images = gen_model(test_noise,training = False)#根據test_noise生成圖片,生成器設置為不可訓練fig = plt.figure(figsize=(4,4))for i in range(pre_images.shape[0]):plt.subplot(4,4,i+1)plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray')#之前歸一化為[-1,1]之間,現在+1然后除以2,使之在[0,1]之間plt.axis('off')fig.savefig("E:/tmp/.keras/datasets/number_gen/%05d.png" % epoch)plt.close()生成圖片如下所示:
8.生成動圖
該模塊參考大神K同學啊
def compose_gif():# 圖片地址data_dir = "E:/tmp/.keras/datasets/number_gen"data_dir = pathlib.Path(data_dir)paths = list(data_dir.glob('*'))gif_images = []for path in paths:print(path)gif_images.append(imageio.imread(path))imageio.mimsave("E:/tmp/.keras/datasets/test.gif", gif_images, fps=2) compose_gif()文件太大,csdn忍不了無法上傳。
由于訓練速度等原因,epochs設置的是100,最終展示的效果并不是很好,但是也可以看出生成的圖片由一片模糊向逐漸清晰的過渡。
努力加油a啊
總結
以上是生活随笔為你收集整理的深度学习之基于GAN实现手写数字生成的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 深度学习之眼睛状态识别混淆矩阵的绘制
- 下一篇: 深度学习之基于DCGAN实现手写数字生成