【百战GAN】新手如何开始你的第一个生成对抗网络(GAN)任务
大家好,歡迎來到專欄《百戰GAN》,我們在公眾號已經輸出了非常多的GAN相關的理論,這一次我們開設《百戰GAN》專欄,在這個專欄里,我們會進行算法的核心思想講解,代碼的詳解,模型的訓練等內容。
作者&編輯 | 言有三
本文資源與生成結果展示
本文篇幅:5000字
背景要求:會使用Python,Tensorflow或者Pytorch
附帶資料:項目推薦,版本包括Pytorch+Tensorflow
同步平臺:有三AI知識星球(一周內)
1 項目背景
生成對抗網絡如今在計算機視覺的很多領域中都被廣泛應用,需要每一個學習深度學習相關技術的算法人員掌握,我們公眾號和知識星球講述了非常多的理論知識,在這個《百戰GAN》專欄中,我們會配合各類實戰案例來幫助大家進行提升,本次項目開發需要以下環境:
(1) Linux系統或者windows系統,使用Linux效率更高。
(2)?安裝好的Tensorflow,CPU或者GPU訓練都可以。
2 原理簡介
今天我們要實踐的模型是DCGAN和CGAN,DCGAN是第一個全卷積GAN,麻雀雖小,五臟俱全,最適合新人實踐。
DCGAN的生成器和判別器都采用了4層的網絡結構。生成器網絡結構如上圖所示,輸入為1×100的向量,然后經過一個全連接層學習,reshape為4×4×1024的張量,再經過4個上采樣的反卷積網絡層,生成64×64的圖,各層的配置如下:
判別器輸入64×64大小的圖,經過4次卷積,分辨率降低為4×4的大小,每一個卷積層的配置如下:
DCGAN并不能控制生成圖片的類別,條件GAN(CGAN)則使用了條件控制變量作為輸入,是幾乎后續所有性能強大的GAN的基礎。網絡結構如下,其中的y就是條件變量。
對于生成器來說,輸入包括z和y,兩者會進行拼接后作為輸入。對于判別器來說,輸入包括了x和y,兩者會進行拼接后作為輸入,當然為了和z以及x進行拼接,y需要做一些維度變換,即reshape操作。
關于它們的理論更加詳細的講解,大家可以移步有三AI知識星球,或者自行閱讀論文。
3 模型訓練
接下來我們進行實踐,選擇tensorflow框架,下面詳解具體的工程代碼,主要包括:
(1) 生成器和判別器模型的定義。
(2) 損失和優化目標的定義。
3.1 DCGAN類定義
首先我們需要定義一個類,設計好輸入輸出,__init__函數如下:
# 模型定義
class DCGAN(object):
? ? def __init__(self, sess, input_height=108, input_width=108, crop=True,
???????? batch_size=64, sample_num = 64, output_height=64, output_width=64,
???????? y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
???????? gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
???????? max_to_keep=1,
???????? input_fname_pattern='*.jpg', checkpoint_dir='ckpts', sample_dir='samples', out_dir='./out', data_dir='./data'):
其中參數解釋如下:sess表示TensorFlow session,batch_size即批處理大小;z_dim是噪聲的維度,默認為100;y_dim是一個可選的條件變量,比如分類標簽,用于CGAN;gf_dim是生成器第一個卷積層的通道數;df_dim是判別器第一個卷積層的通道數;gfc_dim是生成器全連接層維度;dfc_dim是判別器全連接層維度;c_dim是輸入圖像維度,灰度圖為1,彩色圖為3。
從上述代碼可以看出,初始化函數__init__中配置了訓練輸入圖尺寸,批處理大小,輸出圖尺寸,生成器的輸入維度,以及生成器和判別的卷積層和全連接層的若干維度變量。
總結
以上是生活随笔為你收集整理的【百战GAN】新手如何开始你的第一个生成对抗网络(GAN)任务的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【通知】3月第三周直播预告,模型精简前沿
- 下一篇: 【百战GAN】GAN也可以拿来做图像分割