CGAN
詳解GAN代碼之搭建并詳解CGAN代碼
本文鏈接:https://blog.csdn.net/jiongnima/article/details/80209239
訓練數據集:填充輪廓->建筑照片
下載鏈接:https://pan.baidu.com/s/1xUg8AC7NEXyKebSUNtRvdg?密碼:2kw1
? ?CGAN是Conditional?Generative?Adversarial?Nets的縮寫,也稱為條件生成對抗網絡。條件生成對抗網絡指的是在生成對抗網絡中加入條件(condition),條件的作用是監督生成對抗網絡。本篇博客通過簡單代碼搭建,向大家解析了條件生成對抗網絡CGAN。
? ?在開始解析CGAN代碼之前,筆者想說的是,要理解CGAN,還請大家先明了CGAN的原理,筆者在下面提供一些筆者認為比較好的了解CGAN原理的鏈接:
(1) 直接進行論文閱讀:https://arxiv.org/abs/1411.1784
(2) 可以翻閱站內的一篇博客,筆者認為寫得很不錯:Conditional Generative Adversarial Nets論文筆記
(3) 筆者也簡單解析一下CGAN的原理,原理圖如下(截圖來自CGAN論文)
? ?如上圖所示,和原始的生成對抗網絡相比,條件生成對抗網絡CGAN在生成器的輸入和判別器的輸入中都加入了條件y。這個y可以是任何類型的數據(可以是類別標簽,或者其他類型的數據等)。目的是有條件地監督生成器生成的數據,使得生成器生成結果的方式不是完全自由無監督的。
? ?CGAN訓練的目標函數如下圖所示:
? ?從上面的目標函數中可以看到,條件y不僅被送入了判別器的輸入中,也被融入了生成器的輸入中。下面,筆者就來解析CGAN的代碼,首先還是列舉一下筆者主要使用的工具和庫。
?
(1) Python 3.5.2
(2) numpy
(3) Tensorflow 1.2
(4) argparse 用來解析命令行參數
(5) random 用來打亂輸入順序
(6) os 用來讀取圖片路徑和文件名
(7) glob 用來讀取圖片路徑和文件名
(8) cv2 用來讀取圖片
? ?筆者搭建的CGAN代碼分成4大部分,分別是:
(1) train.py 訓練的主控程序
(2) image_reader.py 數據讀取接口
(3) net.py 定義網絡結構
(4) evaluate.py 測試的主控程序
? ?其中,訓練時使用到的文件是(1),(2),(3)項,測試時使用到的文件時(2),(3),(4)。
? ?下面,筆者放出代碼與注釋:
首先是train.py文件中的代碼:
from __future__ import print_functionimport argparse from random import shuffle import random import os import sys import math import tensorflow as tf import glob import cv2from image_reader import * from net import *parser = argparse.ArgumentParser(description='')parser.add_argument("--snapshot_dir", default='./snapshots', help="path of snapshots") #保存模型的路徑 parser.add_argument("--out_dir", default='./train_out', help="path of train outputs") #訓練時保存可視化輸出的路徑 parser.add_argument("--image_size", type=int, default=256, help="load image size") #網絡輸入的尺度 parser.add_argument("--random_seed", type=int, default=1234, help="random seed") #隨機數種子 parser.add_argument('--base_lr', type=float, default=0.0002, help='initial learning rate for adam') #學習率 parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch') #訓練的epoch數量 parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam') #adam優化器的beta1參數 parser.add_argument("--summary_pred_every", type=int, default=200, help="times to summary.") #訓練中每過多少step保存訓練日志(記錄一下loss值) parser.add_argument("--write_pred_every", type=int, default=100, help="times to write.") #訓練中每過多少step保存可視化結果 parser.add_argument("--save_pred_every", type=int, default=5000, help="times to save.") #訓練中每過多少step保存模型(可訓練參數) parser.add_argument("--lamda_l1_weight", type=float, default=0.0, help="L1 lamda") #訓練中L1_Loss前的乘數 parser.add_argument("--lamda_gan_weight", type=float, default=1.0, help="GAN lamda") #訓練中GAN_Loss前的乘數 parser.add_argument("--train_picture_format", default='.png', help="format of training datas.") #網絡訓練輸入的圖片的格式(圖片在CGAN中被當做條件) parser.add_argument("--train_label_format", default='.jpg', help="format of training labels.") #網絡訓練輸入的標簽的格式(標簽在CGAN中被當做真樣本) parser.add_argument("--train_picture_path", default='./dataset/train_picture/', help="path of training datas.") #網絡訓練輸入的圖片路徑 parser.add_argument("--train_label_path", default='./dataset/train_label/', help="path of training labels.") #網絡訓練輸入的標簽路徑args = parser.parse_args() #用來解析命令行參數 EPS = 1e-12 #EPS用于保證log函數里面的參數大于零def save(saver, sess, logdir, step): #保存模型的save函數model_name = 'model' #保存的模型名前綴checkpoint_path = os.path.join(logdir, model_name) #模型的保存路徑與名稱if not os.path.exists(logdir): #如果路徑不存在即創建os.makedirs(logdir)saver.save(sess, checkpoint_path, global_step=step) #保存模型print('The checkpoint has been created.')def cv_inv_proc(img): #cv_inv_proc函數將讀取圖片時歸一化的圖片還原成原圖img_rgb = (img + 1.) * 127.5return img_rgb.astype(np.float32) #返回bgr格式的圖像,方便cv2寫圖像def get_write_picture(picture, gen_label, label, height, width): #get_write_picture函數得到訓練過程中的可視化結果picture_image = cv_inv_proc(picture) #還原輸入的圖像gen_label_image = cv_inv_proc(gen_label[0]) #還原生成的樣本label_image = cv_inv_proc(label) #還原真實的樣本(標簽)inv_picture_image = cv2.resize(picture_image, (width, height)) #還原圖像的尺寸inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) #還原生成的樣本的尺寸inv_label_image = cv2.resize(label_image, (width, height)) #還原真實的樣本的尺寸output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1) #把他們拼起來return outputdef l1_loss(src, dst): #定義l1_lossreturn tf.reduce_mean(tf.abs(src - dst))def main(): #訓練程序的主函數if not os.path.exists(args.snapshot_dir): #如果保存模型參數的文件夾不存在則創建os.makedirs(args.snapshot_dir)if not os.path.exists(args.out_dir): #如果保存訓練中可視化輸出的文件夾不存在則創建os.makedirs(args.out_dir)train_picture_list = glob.glob(os.path.join(args.train_picture_path, "*")) #得到訓練輸入圖像路徑名稱列表tf.set_random_seed(args.random_seed) #初始一下隨機數train_picture = tf.placeholder(tf.float32,shape=[1, args.image_size, args.image_size, 3],name='train_picture') #輸入的訓練圖像train_label = tf.placeholder(tf.float32,shape=[1, args.image_size, args.image_size, 3],name='train_label') #輸入的與訓練圖像匹配的標簽gen_label = generator(image=train_picture, gf_dim=64, reuse=False, name='generator') #得到生成器的輸出dis_real = discriminator(image=train_picture, targets=train_label, df_dim=64, reuse=False, name="discriminator") #判別器返回的對真實標簽的判別結果dis_fake = discriminator(image=train_picture, targets=gen_label, df_dim=64, reuse=True, name="discriminator") #判別器返回的對生成(虛假的)標簽判別結果gen_loss_GAN = tf.reduce_mean(-tf.log(dis_fake + EPS)) #計算生成器損失中的GAN_loss部分gen_loss_L1 = tf.reduce_mean(l1_loss(gen_label, train_label)) #計算生成器損失中的L1_loss部分gen_loss = gen_loss_GAN * args.lamda_gan_weight + gen_loss_L1 * args.lamda_l1_weight #計算生成器的lossdis_loss = tf.reduce_mean(-(tf.log(dis_real + EPS) + tf.log(1 - dis_fake + EPS))) #計算判別器的lossgen_loss_sum = tf.summary.scalar("gen_loss", gen_loss) #記錄生成器loss的日志dis_loss_sum = tf.summary.scalar("dis_loss", dis_loss) #記錄判別器loss的日志summary_writer = tf.summary.FileWriter(args.snapshot_dir, graph=tf.get_default_graph()) #日志記錄器g_vars = [v for v in tf.trainable_variables() if 'generator' in v.name] #所有生成器的可訓練參數d_vars = [v for v in tf.trainable_variables() if 'discriminator' in v.name] #所有判別器的可訓練參數d_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) #判別器訓練器g_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) #生成器訓練器d_grads_and_vars = d_optim.compute_gradients(dis_loss, var_list=d_vars) #計算判別器參數梯度d_train = d_optim.apply_gradients(d_grads_and_vars) #更新判別器參數g_grads_and_vars = g_optim.compute_gradients(gen_loss, var_list=g_vars) #計算生成器參數梯度g_train = g_optim.apply_gradients(g_grads_and_vars) #更新生成器參數train_op = tf.group(d_train, g_train) #train_op表示了參數更新操作config = tf.ConfigProto()config.gpu_options.allow_growth = True #設定顯存不超量使用sess = tf.Session(config=config) #新建會話層init = tf.global_variables_initializer() #參數初始化器sess.run(init) #初始化所有可訓練參數saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=50) #模型保存器counter = 0 #counter記錄訓練步數for epoch in range(args.epoch): #訓練epoch數shuffle(train_picture_list) #每訓練一個epoch,就打亂一下輸入的順序for step in range(len(train_picture_list)): #每個訓練epoch中的訓練step數counter += 1picture_name, _ = os.path.splitext(os.path.basename(train_picture_list[step])) #獲取不包含路徑和格式的輸入圖片名稱#讀取一張訓練圖片,一張訓練標簽,以及相應的高和寬picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name, picture_path=args.train_picture_path, label_path=args.train_label_path, picture_format = args.train_picture_format, label_format = args.train_label_format, size = args.image_size)batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis = 0) #填充維度batch_label = np.expand_dims(np.array(label_resize).astype(np.float32), axis = 0) #填充維度feed_dict = { train_picture : batch_picture, train_label : batch_label } #構造feed_dictgen_loss_value, dis_loss_value, _ = sess.run([gen_loss, dis_loss, train_op], feed_dict=feed_dict) #得到每個step中的生成器和判別器lossif counter % args.save_pred_every == 0: #每過save_pred_every次保存模型save(saver, sess, args.snapshot_dir, counter)if counter % args.summary_pred_every == 0: #每過summary_pred_every次保存訓練日志gen_loss_sum_value, discriminator_sum_value = sess.run([gen_loss_sum, dis_loss_sum], feed_dict=feed_dict)summary_writer.add_summary(gen_loss_sum_value, counter)summary_writer.add_summary(discriminator_sum_value, counter)if counter % args.write_pred_every == 0: #每過write_pred_every次寫一下訓練的可視化結果gen_label_value = sess.run(gen_label, feed_dict=feed_dict) #run出生成器的輸出write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height, picture_width) #得到訓練的可視化結果write_image_name = args.out_dir + "/out"+ str(counter) + ".png" #待保存的訓練可視化結果路徑與名稱cv2.imwrite(write_image_name, write_image) #保存訓練的可視化結果print('epoch {:d} step {:d} \t gen_loss = {:.3f}, dis_loss = {:.3f}'.format(epoch, step, gen_loss_value, dis_loss_value))if __name__ == '__main__':main()然后是image_reader.py文件:
import os import numpy as np import tensorflow as tf import cv2#讀取圖片的函數,接收六個參數 #輸入參數分別是圖片名,圖片路徑,標簽路徑,圖片格式,標簽格式,需要調整的尺寸大小 def ImageReader(file_name, picture_path, label_path, picture_format = ".png", label_format = ".jpg", size = 256):picture_name = picture_path + file_name + picture_format #得到圖片名稱和路徑label_name = label_path + file_name + label_format #得到標簽名稱和路徑picture = cv2.imread(picture_name, 1) #讀取圖片label = cv2.imread(label_name, 1) #讀取標簽height = picture.shape[0] #得到圖片的高width = picture.shape[1] #得到圖片的寬picture_resize_t = cv2.resize(picture, (size, size)) #調整圖片的尺寸,改變成網絡輸入的大小picture_resize = picture_resize_t / 127.5 - 1. #歸一化圖片label_resize_t = cv2.resize(label, (size, size)) #調整標簽的尺寸,改變成網絡輸入的大小label_resize = label_resize_t / 127.5 - 1. #歸一化標簽return picture_resize, label_resize, height, width #返回網絡輸入的圖片,標簽,還有原圖片和標簽的長寬接著是net.py文件:
import numpy as np import tensorflow as tf import math#構造可訓練參數 def make_var(name, shape, trainable = True):return tf.get_variable(name, shape, trainable = trainable)#定義卷積層 def conv2d(input_, output_dim, kernel_size, stride, padding = "SAME", name = "conv2d", biased = False):input_dim = input_.get_shape()[-1]with tf.variable_scope(name):kernel = make_var(name = 'weights', shape=[kernel_size, kernel_size, input_dim, output_dim])output = tf.nn.conv2d(input_, kernel, [1, stride, stride, 1], padding = padding)if biased:biases = make_var(name = 'biases', shape = [output_dim])output = tf.nn.bias_add(output, biases)return output#定義空洞卷積層 def atrous_conv2d(input_, output_dim, kernel_size, dilation, padding = "SAME", name = "atrous_conv2d", biased = False):input_dim = input_.get_shape()[-1]with tf.variable_scope(name):kernel = make_var(name = 'weights', shape = [kernel_size, kernel_size, input_dim, output_dim])output = tf.nn.atrous_conv2d(input_, kernel, dilation, padding = padding)if biased:biases = make_var(name = 'biases', shape = [output_dim])output = tf.nn.bias_add(output, biases)return output#定義反卷積層 def deconv2d(input_, output_dim, kernel_size, stride, padding = "SAME", name = "deconv2d"):input_dim = input_.get_shape()[-1]input_height = int(input_.get_shape()[1])input_width = int(input_.get_shape()[2])with tf.variable_scope(name):kernel = make_var(name = 'weights', shape = [kernel_size, kernel_size, output_dim, input_dim])output = tf.nn.conv2d_transpose(input_, kernel, [1, input_height * 2, input_width * 2, output_dim], [1, 2, 2, 1], padding = "SAME")return output#定義batchnorm(批次歸一化)層 def batch_norm(input_, name="batch_norm"):with tf.variable_scope(name):input_dim = input_.get_shape()[-1]scale = tf.get_variable("scale", [input_dim], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))offset = tf.get_variable("offset", [input_dim], initializer=tf.constant_initializer(0.0))mean, variance = tf.nn.moments(input_, axes=[1,2], keep_dims=True)epsilon = 1e-5inv = tf.rsqrt(variance + epsilon)normalized = (input_-mean)*invoutput = scale*normalized + offsetreturn output#定義lrelu激活層 def lrelu(x, leak=0.2, name = "lrelu"):return tf.maximum(x, leak*x)#定義生成器,采用UNet架構,主要由8個卷積層和8個反卷積層組成 def generator(image, gf_dim=64, reuse=False, name="generator"):input_dim = int(image.get_shape()[-1]) #獲取輸入通道dropout_rate = 0.5 #定義dropout的比例with tf.variable_scope(name):if reuse:tf.get_variable_scope().reuse_variables()else:assert tf.get_variable_scope().reuse is False#第一個卷積層,輸出尺度[1, 128, 128, 64]e1 = batch_norm(conv2d(input_=image, output_dim=gf_dim, kernel_size=4, stride=2, name='g_e1_conv'), name='g_bn_e1')#第二個卷積層,輸出尺度[1, 64, 64, 128]e2 = batch_norm(conv2d(input_=lrelu(e1), output_dim=gf_dim*2, kernel_size=4, stride=2, name='g_e2_conv'), name='g_bn_e2')#第三個卷積層,輸出尺度[1, 32, 32, 256]e3 = batch_norm(conv2d(input_=lrelu(e2), output_dim=gf_dim*4, kernel_size=4, stride=2, name='g_e3_conv'), name='g_bn_e3')#第四個卷積層,輸出尺度[1, 16, 16, 512]e4 = batch_norm(conv2d(input_=lrelu(e3), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e4_conv'), name='g_bn_e4')#第五個卷積層,輸出尺度[1, 8, 8, 512]e5 = batch_norm(conv2d(input_=lrelu(e4), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e5_conv'), name='g_bn_e5')#第六個卷積層,輸出尺度[1, 4, 4, 512]e6 = batch_norm(conv2d(input_=lrelu(e5), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e6_conv'), name='g_bn_e6')#第七個卷積層,輸出尺度[1, 2, 2, 512]e7 = batch_norm(conv2d(input_=lrelu(e6), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e7_conv'), name='g_bn_e7')#第八個卷積層,輸出尺度[1, 1, 1, 512]e8 = batch_norm(conv2d(input_=lrelu(e7), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e8_conv'), name='g_bn_e8')#第一個反卷積層,輸出尺度[1, 2, 2, 512]d1 = deconv2d(input_=tf.nn.relu(e8), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d1')d1 = tf.nn.dropout(d1, dropout_rate) #隨機扔掉一般的輸出d1 = tf.concat([batch_norm(d1, name='g_bn_d1'), e7], 3)#第二個反卷積層,輸出尺度[1, 4, 4, 512]d2 = deconv2d(input_=tf.nn.relu(d1), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d2')d2 = tf.nn.dropout(d2, dropout_rate) #隨機扔掉一般的輸出d2 = tf.concat([batch_norm(d2, name='g_bn_d2'), e6], 3)#第三個反卷積層,輸出尺度[1, 8, 8, 512]d3 = deconv2d(input_=tf.nn.relu(d2), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d3')d3 = tf.nn.dropout(d3, dropout_rate) #隨機扔掉一般的輸出d3 = tf.concat([batch_norm(d3, name='g_bn_d3'), e5], 3)#第四個反卷積層,輸出尺度[1, 16, 16, 512]d4 = deconv2d(input_=tf.nn.relu(d3), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d4')d4 = tf.concat([batch_norm(d4, name='g_bn_d4'), e4], 3)#第五個反卷積層,輸出尺度[1, 32, 32, 256]d5 = deconv2d(input_=tf.nn.relu(d4), output_dim=gf_dim*4, kernel_size=4, stride=2, name='g_d5')d5 = tf.concat([batch_norm(d5, name='g_bn_d5'), e3], 3)#第六個反卷積層,輸出尺度[1, 64, 64, 128]d6 = deconv2d(input_=tf.nn.relu(d5), output_dim=gf_dim*2, kernel_size=4, stride=2, name='g_d6')d6 = tf.concat([batch_norm(d6, name='g_bn_d6'), e2], 3)#第七個反卷積層,輸出尺度[1, 128, 128, 64]d7 = deconv2d(input_=tf.nn.relu(d6), output_dim=gf_dim, kernel_size=4, stride=2, name='g_d7')d7 = tf.concat([batch_norm(d7, name='g_bn_d7'), e1], 3)#第八個反卷積層,輸出尺度[1, 256, 256, 3]d8 = deconv2d(input_=tf.nn.relu(d7), output_dim=input_dim, kernel_size=4, stride=2, name='g_d8')return tf.nn.tanh(d8)#定義判別器 def discriminator(image, targets, df_dim=64, reuse=False, name="discriminator"):with tf.variable_scope(name):if reuse:tf.get_variable_scope().reuse_variables()else:assert tf.get_variable_scope().reuse is Falsedis_input = tf.concat([image, targets], 3)#第1個卷積模塊,輸出尺度: 1*128*128*64h0 = lrelu(conv2d(input_ = dis_input, output_dim = df_dim, kernel_size = 4, stride = 2, name='d_h0_conv'))#第2個卷積模塊,輸出尺度: 1*64*64*128h1 = lrelu(batch_norm(conv2d(input_ = h0, output_dim = df_dim*2, kernel_size = 4, stride = 2, name='d_h1_conv'), name='d_bn1'))#第3個卷積模塊,輸出尺度: 1*32*32*256h2 = lrelu(batch_norm(conv2d(input_ = h1, output_dim = df_dim*4, kernel_size = 4, stride = 2, name='d_h2_conv'), name='d_bn2'))#第4個卷積模塊,輸出尺度: 1*32*32*512h3 = lrelu(batch_norm(conv2d(input_ = h2, output_dim = df_dim*8, kernel_size = 4, stride = 1, name='d_h3_conv'), name='d_bn3'))#最后一個卷積模塊,輸出尺度: 1*32*32*1output = conv2d(input_ = h3, output_dim = 1, kernel_size = 4, stride = 1, name='d_h4_conv')dis_out = tf.sigmoid(output) #在輸出之前經過sigmoid層,因為需要進行log運算return dis_out 上面就是訓練所需的全部代碼,大家可以看到,在net.py文件中。生成器使用UNet結構,在生成器和判別器中,image參數就是指的條件,并且在生成器的輸入中,隨機噪聲被去掉了(僅僅輸入了條件);在判別器的輸入中,條件和待判別的圖像被拼接(concat)了起來。? ?如果需要開啟訓練,可以調整train.py中的最后四個參數,根據自己的需求調整訓練輸入的圖片和標簽文件路徑和相應的格式。另外,由于CGAN訓練中需要匹配條件與判別圖片,因此,訓練讀取的圖片和標簽名稱應該是匹配的,在image_reader.py中也能看到,程序是按照同一個名稱,去檢索訓練一個批次輸入的圖像和對應的標簽。
下面是evaluate.py文件:
import argparse import sys import math import tensorflow as tf import numpy as np import glob import cv2from image_reader import * from net import *parser = argparse.ArgumentParser(description='')parser.add_argument("--test_picture_path", default='./dataset/test_picture/', help="path of test datas.")#網絡測試輸入的圖片路徑 parser.add_argument("--test_label_path", default='./dataset/test_label/', help="path of test datas.") #網絡測試輸入的標簽路徑 parser.add_argument("--image_size", type=int, default=256, help="load image size") #網絡輸入的尺度 parser.add_argument("--test_picture_format", default='.png', help="format of test pictures.") #網絡測試輸入的圖片的格式 parser.add_argument("--test_label_format", default='.jpg', help="format of test labels.") #網絡測試時讀取的標簽的格式 parser.add_argument("--snapshots", default='./snapshots/',help="Path of Snapshots") #讀取訓練好的模型參數的路徑 parser.add_argument("--out_dir", default='./test_output/',help="Output Folder") #保存網絡測試輸出圖片的路徑args = parser.parse_args() #用來解析命令行參數def cv_inv_proc(img): #cv_inv_proc函數將讀取圖片時歸一化的圖片還原成原圖img_rgb = (img + 1.) * 127.5return img_rgb.astype(np.float32) #返回bgr格式的圖像,方便cv2寫圖像def get_write_picture(picture, gen_label, label, height, width): #get_write_picture函數得到網絡測試的結果picture_image = cv_inv_proc(picture) #還原輸入的圖像gen_label_image = cv_inv_proc(gen_label[0]) #還原生成的結果label_image = cv_inv_proc(label) #還原讀取的標簽inv_picture_image = cv2.resize(picture_image, (width, height)) #將輸入圖像還原到原大小inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) #將生成的結果還原到原大小inv_label_image = cv2.resize(label_image, (width, height)) #將標簽還原到原大小output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1) #拼接得到輸出結果return outputdef main():if not os.path.exists(args.out_dir): #如果保存測試結果的文件夾不存在則創建os.makedirs(args.out_dir)test_picture_list = glob.glob(os.path.join(args.test_picture_path, "*")) #得到測試輸入圖像路徑名稱列表test_picture = tf.placeholder(tf.float32, shape=[1, 256, 256, 3], name='test_picture') #測試輸入的圖像gen_label = generator(image=test_picture, gf_dim=64, reuse=False, name='generator') #得到生成器的生成結果restore_var = [v for v in tf.global_variables() if 'generator' in v.name] #需要載入的已訓練的模型參數config = tf.ConfigProto()config.gpu_options.allow_growth = True #設定顯存不超量使用sess = tf.Session(config=config) #建立會話層saver = tf.train.Saver(var_list=restore_var, max_to_keep=1) #導入模型參數時使用checkpoint = tf.train.latest_checkpoint(args.snapshots) #讀取模型參數saver.restore(sess, checkpoint) #導入模型參數for step in range(len(test_picture_list)):picture_name, _ = os.path.splitext(os.path.basename(test_picture_list[step])) #得到一張網絡測試的輸入圖像名字#讀取一張測試圖片,一張標簽,以及相應的高和寬picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name,picture_path=args.test_picture_path,label_path=args.test_label_path,picture_format=args.test_picture_format,label_format=args.test_label_format,size=args.image_size)batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis=0) #填充維度feed_dict = {test_picture: batch_picture} #構造feed_dictgen_label_value = sess.run(gen_label, feed_dict=feed_dict) #得到生成結果write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height, picture_width) #得到一張需要存的圖像write_image_name = args.out_dir + picture_name + ".png" #為上述的圖像構造保存路徑與文件名cv2.imwrite(write_image_name, write_image) #保存測試結果print('step {:d}'.format(step))if __name__ == '__main__':main()? ?如果需要測試訓練完畢的模型,相應地更改測試圖片和標簽輸入路徑和格式的四個參數即可,并設置讀取模型權重的路徑即可。
? ?下面,筆者就以訓練的填充輪廓生成建筑圖片的例子為大家展示一下CGAN的效果:
? ?首先是訓練時的可視化輸出圖像,從左到右,第一張是網絡的輸入圖片(條件),第二張是生成器生成的建筑圖像,第三張是真實的建筑圖像(標簽)。
首先是訓練200次的輸出:
然后是訓練5600次的輸出:
然后是訓練19000次的輸出:
然后是訓練36500次的輸出:
然后是訓練46700次的輸出:
然后是訓練65700次的輸出:
然后是訓練72400次的輸出:
最后是訓練96300次的輸出:
下面展示一下訓練的loss曲線:
生成器的loss曲線:
判別器的loss曲線:
最后展示一下在測試集上面的效果:
左邊是輸入的圖像(條件),中間是生成的圖像,右邊是標簽(真實的樣本)。
? ?上面就是在測試集上面的效果,讀者朋友們可以從文章開頭筆者放出的鏈接中下載數據集進行實驗。
? ?在train.py中,如果將lamda_l1_weight參數改成100,就是pix2pix的做法,筆者放了一些測試集的效果(訓練有一些過擬合):
? ?到這里,CGAN的模型搭建及解析就接近尾聲了,很感謝Mehdi Mirza和Simon Osindero,為大家帶來條件監督的生成對抗網絡算法。CGAN還可以做很多有趣的事情,比如說這個有趣的工作:AI可能真的要代替插畫師了……,項目鏈接https://make.girls.moe/#/,通過CGAN有條件地生成二次元萌妹。
總結
- 上一篇: GAN
- 下一篇: python Demo 01 爬取大学