Tensorflow2.0实现对抗生成网络(GAN)
在這篇文章中,我們使用Tensorflow2.0來實(shí)現(xiàn)GAN,使用的數(shù)據(jù)集是手寫數(shù)字?jǐn)?shù)據(jù)集。
引入需要的庫(kù)
import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers import matplotlib.pyplot as plt %matplotlib inline導(dǎo)入數(shù)據(jù),歸一化數(shù)據(jù)
(train_images, train_labels), (_, _) = keras.datasets.mnist.load_data() train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') train_images = (train_images-127.5)/127.5BATCH_SIZE = 256 BUFFER_SIZE = 60000datasets = tf.data.Dataset.from_tensor_slices(train_images) datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)建立生成器
def generator_model(): # 用100個(gè)隨機(jī)數(shù)(噪音)生成手寫數(shù)據(jù)集model = keras.Sequential()model.add(layers.Dense(256, input_shape=(100,), use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(512, use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(28*28*1, use_bias=False, activation='tanh'))model.add(layers.BatchNormalization())model.add(layers.Reshape((28, 28, 1)))return model建立判別器
def discriminator_model(): # 識(shí)別輸入的圖片model = keras.Sequential()model.add(layers.Flatten())model.add(layers.Dense(512, use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(256, use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(1))return model分別定義判別器和生成器的損失函數(shù)
對(duì)于判別器來說,我們需要將導(dǎo)入的原始圖片識(shí)別為真(1),將生成器勝場(chǎng)的圖像識(shí)別為假(0)。
對(duì)于生成器來說,我們需要使得生成的圖片無(wú)限接近于真實(shí)圖片。
在以上代碼中,real_out是指向判別器輸入原始圖像得到的結(jié)果;fake_out是指向判別器輸入生成圖像得到的結(jié)果。
所以對(duì)于判別器的損失函數(shù)來說,real_out應(yīng)該無(wú)限接近于1;fake_out應(yīng)該無(wú)限接近于0。即我們想訓(xùn)練出的判別器應(yīng)該對(duì)圖片有很高的識(shí)別能力。
但對(duì)于生成器的損失函數(shù)來說,fake_out應(yīng)該無(wú)限接近于1,也就是令判別器很難分辨出生成的圖片。
【注】keras.losses.BinaryCrossentropy(from_logits=True)的用法可以參考:tensorflow2.0中損失函數(shù)tf.keras.losses.BinaryCrossentropy()的用法。
分別定義生成器和判別器的優(yōu)化函數(shù)
generator_opt = keras.optimizers.Adam(1e-4) discriminator_opt = keras.optimizers.Adam(1e-4)實(shí)例化生成器和判別器
generator = generator_model() discriminator = discriminator_model()定義訓(xùn)練過程
noise_dim = 100 # 即用100個(gè)隨機(jī)數(shù)生成圖片def train_step(images):noise = tf.random.normal([BATCH_SIZE, noise_dim])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:real_out = discriminator(images, training=True)gen_image = generator(noise, training=True)fake_out = discriminator(gen_image, training=True)gen_loss = generator_loss(fake_out)disc_loss = discriminator_loss(real_out, fake_out)gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)gradient_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))discriminator_opt.apply_gradients(zip(gradient_disc, discriminator.trainable_variables))gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)表示計(jì)算gen_loss對(duì)于generator的所有變量的梯度。
generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))表示根據(jù)gradient_gen來優(yōu)化generator的變量。
【注】梯度帶及梯度更新的用法參考:Tensorflow中的梯度帶(GradientTape)以及梯度更新。
定義繪圖函數(shù)
def generate_plot_image(gen_model, test_noise):pre_images = gen_model(test_noise, training=False)fig = plt.figure(figsize=(4, 4))for i in range(pre_images.shape[0]):plt.subplot(4, 4, i+1)plt.imshow((pre_images[i, :, :, 0] + 1)/2, cmap='gray')plt.axis('off')plt.show()plt.imshow((pre_images[i, :, :, 0] + 1)/2, cmap=‘gray’)
這里是因?yàn)槲覀兪褂胻anh激活函數(shù)之后會(huì)將結(jié)果限制在-1到1之間,而我們需要將其轉(zhuǎn)化到0到1之間。
定義訓(xùn)練函數(shù)
EPOCHS = 100 # 訓(xùn)練100次 num_exp_to_generate = 16 # 生成16張圖片 seed = tf.random.normal([num_exp_to_generate, noise_dim]) # 16組隨機(jī)數(shù)組,每組含100個(gè)隨機(jī)數(shù),用來生成16張圖片。def train(dataset, epochs):for epoch in range(epochs):for image_batch in dataset:train_step(image_batch)print('.', end='')generate_plot_image(generator, seed) train(datasets, EPOCHS)總結(jié)
以上是生活随笔為你收集整理的Tensorflow2.0实现对抗生成网络(GAN)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: AD20原理图转pcb失败
- 下一篇: 一种可实时处理 O(1)复杂度图像去雾算