GAN(生成对抗网络) 解释
生活随笔
收集整理的這篇文章主要介紹了
GAN(生成对抗网络) 解释
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
GAN (生成對抗網絡)是近幾年深度學習中一個比較熱門的研究方向,它的變種有上千種。
1.什么是GAN
GAN的英文全稱是Generative Adversarial Network,中文名是生成對抗網絡。它由兩個部分組成,生成器和鑒別器(又稱判別器),它們之間的關系可以用競爭或敵對關系來描述。
我們可以拿捕食者與被捕食者之間的例子來類似說明兩者之間的關系。在生物進化的過程中,被捕食者會慢慢演化自己的特征,使自己越來越不容易被捕食者識別捕捉到,從而達到欺騙捕食者的目的;與此同時,捕食者也會隨著被捕食者的演化來演化自己對被捕食者的識別,使自己越來越容易識別捕捉到捕食者。這樣就可以達到兩者共同進化的目的。生成器代表的是被捕食者,鑒別器代表的是捕食者。
2.GAN的原理
GAN的工作原理與上述例子還有略微的不同,GAN是已經知道最終鑒別的目標是什么,但不知道假目標是什么,它會對生成器所產生的假目標做懲罰并對真目標進行獎勵,這樣鑒別器就知道了不好的假目標與好的真目標具體是什么。生成器則是希望通過進化,產生比上一次更好的假目標,使鑒別器對自己的懲罰更小。以上是一個循環,在下一個循環中鑒別器通過學習上一個循環進化出的假目標和真目標,再次進化對假目標的懲罰,同時生成器再次進化,直到與真目標一致,結束進化。
GAN簡單代碼實現
#是一個卷積神經網絡,變量名是D,其中一層構造方式如下。 w = tf.get_variable('w', [4, 4, c_dim, num_filter], initializer=tf.truncated_normal_initializer(stddev=stddev)) dconv = tf.nn.conv2d(ddata, w, strides=[1, 2, 2, 1], padding='SAME') biases = tf.get_variable('biases', [num_filter], initializer=tf.constant_initializer(0.0)) bias = tf.nn.bias_add(dconv, biases) dconv1 = tf.maximum(bias, leak*bias)#是一個逆卷積神經網絡,變量名是G,其中一層構造方式如下。 w = tf.get_variable('w', [4, 4, num_filter, num_filter*2], initializer=tf.random_normal_initializer(stddev=stddev)) deconv = tf.nn.conv2d_transpose(gconv2, w, output_shape=[batch_size, s2, s2, num_filter], strides=[1, 2, 2, 1]) biases = tf.get_variable('biases', [num_filter], initializer=tf.constant_initializer(0.0)) bias = tf.nn.bias_add(deconv, biases) deconv1 = tf.nn.relu(bias, name=scope.name)#的網絡輸入為一個維服從-1~1均勻分布的隨機變量,這里取的是100. batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]).astype(np.float32) #的網絡輸入是一個batch的64*64的圖片, #既可以是手寫體數據也可以是的一個batch的輸出。#這個過程可以參考上圖的a狀態,判別曲線處于不夠穩定的狀態, #兩個網絡都還沒訓練好。#訓練判別網絡 #判別網絡的損失函數由兩部分組成,一部分是真實數據判別為1的損失,一部分是的輸出self.G#判別為0的損失,需要優化的損失函數定義如下。self.G = self.generator(self.z) self.D, self.D_logits = self.discriminator(self.images) self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True) self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D))) self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_))) self.d_loss = self.d_loss_real + self.d_loss_fake#然后將一個batch的真實數據batch_images,和隨機變量batch_z當做輸入,執行session更新的參數。 ##### update discriminator on real d_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, beta1=FLAGS.beta1).minimize(d_loss, var_list=d_vars) ... out1 = sess.run([d_optim], feed_dict={real_images: batch_images, noise_images: batch_z})#這一步可以對比圖b,判別曲線漸漸趨于平穩。 #訓練生成網絡 #生成網絡并沒有一個獨立的目標函數,它更新網絡的梯度來源是判別網絡對偽造圖片求的梯度, #并且是在設定偽造圖片的label是1的情況下,保持判別網絡不變, #那么判別網絡對偽造圖片的梯度就是向著真實圖片變化的方向。self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_))) #然后用同樣的隨機變量batch_z當做輸入更新g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) .minimize(self.g_loss, var_list=self.g_vars) out2 = sess.run([g_optim], feed_dict={noise_images:batch_z})參考資料:
link1
link2
總結
以上是生活随笔為你收集整理的GAN(生成对抗网络) 解释的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 《数据库系统概念》学习笔记——恢复系统
- 下一篇: RBAC角色访问控制