circlegan_CycleGAN原理以及代码全解析
許多名畫造假者費盡畢生的心血,試圖模仿出藝術名家的風格。如今,CycleGAN就可以初步實現這個神奇的功能。這個功能就是風格遷移,比如下圖,照片可以被賦予莫奈,梵高等人的繪畫風格
這屬于是無配對數據(unpaired)產生的圖片,也就是說你有一些名人名家的作品,也有一些你想轉換風格的真實圖片,這兩種圖片是沒有任何交集的。在之前的文章(用AI增強人類想象力)中提到的Pix2Pix方法的關鍵是提供了在這兩個域中有相同數據的訓練樣本。CycleGAN的創新點在于能夠在源域和目標域之間,無須建立訓練數據間一對一的映射,就可實現這種遷移
想要做到這點,有兩個比較重要的點,第一個就是雙判別器。如上圖a所示,兩個分布X,Y,生成器G,F分別是X到Y和Y到X的映射,兩個判別器Dx,Dy可以對轉換后的圖片進行判別。第二個點就是cycle-consistency loss,用數據集中其他的圖來檢驗生成器,這是防止G和F過擬合,比如想把一個小狗照片轉化成梵高風格,如果沒有cycle-consistency loss,生成器可能會生成一張梵高真實畫作來騙過Dx,而無視輸入的小狗。
需要注意的是,廣為流傳的下圖,有個容易讓人理解錯誤的地方,那就是下圖中的input和output那幾張圖,兩匹馬應該除了花紋其他一致的,除此之外,結構還是挺清晰的
對抗損失
生成器和判別器的loss函數和GAN是一樣的,判別器D盡力檢測出生成器G產生的假圖片,生成器盡力生成圖片騙過判別器,具體數理推導可以看我專欄之前的文章李剛:GAN 對抗生成網絡入門輔助理解?zhuanlan.zhihu.com
對抗loss由兩部分組成:
以及
Cycle Consistency 損失
作者說:理論上,對抗訓練可以學習映射輸出G和F,它們分別作為目標域Y和X產生相同的分布。然而,具有足夠大的容量,網絡可以將相同的輸入圖像集合映射到目標域中的任何圖像的隨機排列。因此,單獨的對抗性loss不能保證可以映射單個輸入。需要另外來一個loss,保證G和F不僅能滿足各自的判別器,還能應用于其他圖片。也就是說,G和F可能合伙偷懶騙人,給G一個圖,G偷偷把小狗變成梵高自畫像,F再把梵高自畫像變成輸入。Cycle Consistency loss的到來制止了這種投機取巧的行為,他用梵高其他的畫作測試FG,用另外真實照片測試GF,看看能否變回到原來的樣子,這樣保證了GF在整個X,Y分布區間的普適性。
整體
所以,整個loss就是下面的式子,就像訓練兩個auoto-encoder一樣
作者在后文比對了單獨拿出不同部分的效果,比如只用Cycle Consistency loss,只用對抗,GAN + 前向cycle-consistency loss (F(G(x)) ≈ x),, GAN + 后向 cycle-consistency loss (G(F(y)) ≈ y),以及cycleGAN的效果。
代碼實現
首先是一些參數
ngf = 32 # Number of filters in first layer of generator
ndf = 64 # Number of filters in first layer of discriminator
batch_size = 1 # batch_size
pool_size = 50 # pool_size
img_width = 256 # Imput image will of width 256
img_height = 256 # Input image will be of height 256
img_depth = 3 # RGB format
構造生成器Generator(Encoder+Transformer+Decoder)
假設所有圖片都是256*256的彩圖,需要先用卷積神經網絡提取特征,在這里,input_gen是輸入圖像,num_features是我們從卷積層中提取出的輸出特征的數量(濾波器的數量)window_width,window_height代表濾波器尺寸。 stride_width,strideheight是濾波器如何在整個圖上移動的參數。輸出的O_C1是尺寸[256,256,32]的矩陣。也可以在后邊自行添加Relu等函數。
o_c1 = general_conv2d(input_gen,
num_features=ngf,
window_width=7,
window_height=7,
stride_width=1,
stride_height=1)
#定義卷積層函數
def general_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1):
with tf.variable_scope(name):
conv = tf.contrib.layers.conv2d(inputconv, num_features, [window_width, window_height], [stride_width, stride_height],
padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
biases_initializer=tf.constant_initializer(0.0))
后面是相似的卷積步驟,最后一層輸出o_enc_A是(64,64,256)的矩陣
o_c2 = general_conv2d(o_c1, num_features=64*2, window_width=3, window_height=3, stride_width=2, stride_height=2)
# o_c2.shape = (128, 128, 128)
o_enc_A = general_conv2d(o_c2, num_features=64*4, window_width=3, window_height=3, stride_width=2, stride_height=2)
# o_enc_A.shape = (64, 64, 256)
Transformer可以將這些層視為圖像的不同附近特征的組合,然后基于這些特征來決定如何將圖像的特征向量轉換到另一個分布。作者使用了6層resnet塊,其中輸入的殘差被添加到輸出中。這樣做是為了確保先前層的輸入的屬性也可用于以后的層,因此它們的輸出不會偏離原始輸入,否則原始圖像的特性將不被保留在輸出中。任務的主要目的之一是保留原始輸入的特性,如對象的大小和形狀,因此殘差網絡非常適合這些類型的變換。關于resnet,詳見 ResNet原理及其在TF-Slim中的實現
o_r1 = build_resnet_block(o_enc_A, num_features=64*4)
o_r2 = build_resnet_block(o_r1, num_features=64*4)
o_r3 = build_resnet_block(o_r2, num_features=64*4)
o_r4 = build_resnet_block(o_r3, num_features=64*4)
o_r5 = build_resnet_block(o_r4, num_features=64*4)
o_enc_B = build_resnet_block(o_r5, num_features=64*4)
#定義resnet
def resnet_blocks(input_res, num_features):
out_res_1 = general_conv2d(input_res, num_features,
window_width=3,
window_heigth=3,
stride_width=1,
stride_heigth=1)
out_res_2 = general_conv2d(out_res_1, num_features,
window_width=3,
window_heigth=3,
stride_width=1,
stride_heigth=1)
return (out_res_2 + input_res)
下面是decoder,用反卷積把這些特征變回成圖片
o_d1 = general_deconv2d(o_enc_B, num_features=ngf*2 window_width=3, window_height=3, stride_width=2, stride_height=2)
o_d2 = general_deconv2d(o_d1, num_features=ngf, window_width=3, window_height=3, stride_width=2, stride_height=2)
gen_B = general_conv2d(o_d2, num_features=3, window_width=7, window_height=7, stride_width=1, stride_height=1)
#定義反卷積層
def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="deconv2d", do_norm=True, do_relu=True, relufactor=0):
with tf.variable_scope(name):
conv = tf.contrib.layers.conv2d_transpose(inputconv, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),biases_initializer=tf.constant_initializer(0.0))
if do_norm:
conv = instance_norm(conv)
# conv = tf.contrib.layers.batch_norm(conv, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope="batch_norm")
if do_relu:
if(relufactor == 0):
conv = tf.nn.relu(conv,"relu")
else:
conv = lrelu(conv, relufactor, "lrelu")
return conv
判別器的構成在這里救不贅述了,無非就是用CNN把生成的圖片變成一些特征圖,再用全連接變成最后的decision(真或假)
定義loss function
判別器loss:loss_1是對于真圖的判定,越接近1越好,loss_2是對于假圖的判定,越接近0越好,loss是兩個loss相加
D_A_loss_1 = tf.reduce_mean(tf.squared_difference(dec_A,1))
D_B_loss_1 = tf.reduce_mean(tf.squared_difference(dec_B,1))
D_A_loss_2 = tf.reduce_mean(tf.square(dec_gen_A))
D_B_loss_2 = tf.reduce_mean(tf.square(dec_gen_B))
D_A_loss = (D_A_loss_1 + D_A_loss_2)/2
D_B_loss = (D_B_loss_1 + D_B_loss_2)/2
生成器loss:
g_loss_B_1 = tf.reduce_mean(tf.squared_difference(dec_gen_A,1))
g_loss_A_1 = tf.reduce_mean(tf.squared_difference(dec_gen_A,1))
Cycle Consistency loss: 保證原始圖像和循環圖像之間的差異應該盡可能小,注意10*cyc_loss是賦予Cycle Consistency loss更大的權值,作者并沒有討論這個參數是怎么確定下來的
cyc_loss = tf.reduce_mean(tf.abs(input_A-cyc_A)) + tf.reduce_mean(tf.abs(input_B-cyc_B))
g_loss_A = g_loss_A_1 + 10*cyc_loss
g_loss_B = g_loss_B_1 + 10*cyc_loss
模型訓練
for epoch in range(0,100):
# Define the learning rate schedule. The learning rate is kept
# constant upto 100 epochs and then slowly decayed
if(epoch < 100) :
curr_lr = 0.0002
else:
curr_lr = 0.0002 - 0.0002*(epoch-100)/100
# Running the training loop for all batches
for ptr in range(0,num_images):
# Train generator G_A->B
_, gen_B_temp = sess.run([g_A_trainer, gen_B],
feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
# We need gen_B_temp because to calculate the error in training D_B
_ = sess.run([d_B_trainer],
feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
# Same for G_B->A and D_A as follow
_, gen_A_temp = sess.run([g_B_trainer, gen_A],
feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
_ = sess.run([d_A_trainer],
feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
創作挑戰賽新人創作獎勵來咯,堅持創作打卡瓜分現金大獎總結
以上是生活随笔為你收集整理的circlegan_CycleGAN原理以及代码全解析的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: linux mongo 服务器,如何用M
- 下一篇: onclick进不去ajax,在ajax