styleGAN
數據集
在不同分辨率下在不同數據集上訓練的預訓練 StyleGAN 模型的集合。
| LSUN Bedrooms | ||
| LSUN Cars | ||
| LSUN Cats | ||
| CelebA HQ Faces | ||
| FFHQ Faces | ||
| Pokemon | ||
| Anime Faces | ||
| Anime Portraits | ||
| WikiArt Faces | ||
| Abstract Photos | ||
| Vases | ||
| Fireworks | ||
| Ukiyo-e Faces | ||
| Butterflies |
使用預訓練網絡
pretrained_example.py 中給出了使用預訓練 StyleGAN 生成器的最小示例。 執行時,腳本會從 Google Drive 下載一個預先訓練好的 StyleGAN 生成器,并使用它來生成圖像:
> python pretrained_example.py Downloading https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ .... doneGs Params OutputShape WeightShape --- --- --- --- latents_in - (?, 512) - ... images_out - (?, 3, 1024, 1024) - --- --- --- --- Total 26219627> ls results example.png # https://drive.google.com/uc?id=1UDLT_zb-rof9kKH0GwiJW_bS9MoZi8oPgenerate_figures.py 中給出了一個更高級的示例。 該腳本復制了我們論文中的數字,以說明樣式混合、噪聲輸入和截斷:
> python generate_figures.py results/figure02-uncurated-ffhq.png # https://drive.google.com/uc?id=1U3r1xgcD7o-Fd0SBRpq8PXYajm7_30cu results/figure03-style-mixing.png # https://drive.google.com/uc?id=1U-nlMDtpnf1RcYkaFQtbh5oxnhA97hy6 results/figure04-noise-detail.png # https://drive.google.com/uc?id=1UX3m39u_DTU6eLnEW6MqGzbwPFt2R9cG results/figure05-noise-components.png # https://drive.google.com/uc?id=1UQKPcvYVeWMRccGMbs2pPD9PVv1QDyp_ results/figure08-truncation-trick.png # https://drive.google.com/uc?id=1ULea0C12zGlxdDQFNLXOWZCHi3QNfk_v results/figure10-uncurated-bedrooms.png # https://drive.google.com/uc?id=1UEBnms1XMfj78OHj3_cx80mUf_m9DUJr results/figure11-uncurated-cars.png # https://drive.google.com/uc?id=1UO-4JtAs64Kun5vIj10UXqAJ1d5Ir1Ke results/figure12-uncurated-cats.png # https://drive.google.com/uc?id=1USnJc14prlu3QAYxstrtlfXC9sDWPA-W預訓練的網絡作為標準 pickle 文件存儲在 Google Drive 上:
# Load pre-trained network. url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:_G, _D, Gs = pickle.load(f)# _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.# _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.# Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.上面的代碼下載文件并解壓它以產生 dnnlib.tflib.Network 的 3 個實例。 要生成圖像,您通常需要使用 Gs——另外兩個網絡是為了完整性而提供的。 為了使 pickle.load() 工作,您需要在 PYTHONPATH 中包含 dnnlib 源目錄,并將 tf.Session 設置為默認值。 會話可以通過調用 dnnlib.tflib.init_tf() 來初始化。
使用預訓練生成器的三種方式:
-
使用 Gs.run() 進行立即模式操作,其中輸入和輸出是 numpy 數組:
# Pick latent vector. rnd = np.random.RandomState(5) latents = rnd.randn(1, Gs.input_shape[1])# Generate image. fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)
第一個參數是一批形狀為 [num, 512] 的潛在向量。 第二個參數是為類標簽保留的(StyleGAN 不使用)。 其余的關鍵字參數是可選的,可用于進一步修改操作(見下文)。 輸出是一批圖像,其格式由 output_transform 參數指定。 -
使用 Gs.get_output_for() 將生成器合并為更大的 TensorFlow 表達式的一部分:
latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True) images = tflib.convert_images_to_uint8(images) result_expr.append(inception_clone.get_output_for(images))
上面的代碼來自metrics/frechet_inception_distance.py。 它生成一批隨機圖像并將它們直接提供給 Inception-v3 網絡,而無需將數據轉換為中間的 numpy 數組。 -
查找 Gs.components.mapping 和 Gs.components.synthesis 以訪問生成器的各個子網絡。 與 Gs 類似,子網絡表示為 dnnlib.tflib.Network 的獨立實例:
src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds) src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component] src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
上面的代碼來自 generate_figures.py。 它首先使用映射網絡將一批潛在向量轉換為中間 W 空間,然后使用合成網絡將這些向量轉換為一批圖像。 dlatents 數組為合成網絡的每一層存儲相同 w 向量的單獨副本,以促進風格混合。
生成器的確切細節在 training/networks_stylegan.py 中定義(參見 G_style、G_mapping 和 G_synthesis)。 可以指定以下關鍵字參數來修改調用 run() 和 get_output_for() 時的行為:
- truncation_psi 和 truncation_cutoff 控制使用 Gs (ψ=0.7, cutoff=8) 時默認執行的截斷技巧。 可以通過設置 truncation_psi=1 或 is_validation=True 來禁用它,并且可以通過設置以變化為代價進一步提高圖像質量,例如 截斷_psi=0.5。 請注意,直接使用子網時,截斷始終處于禁用狀態。 可以使用 Gs.get_var(‘dlatent_avg’) 查找手動執行截斷技巧所需的平均 w。
- randomize_noise 確定是否對每個生成的圖像使用重新隨機化噪聲輸入(True,默認)或是否對整個 minibatch 使用特定的噪聲值(False)。 可以通過使用 [var for name, var in Gs.components.synthesis.vars.items() if name.startswith(‘noise’)] 找到的 tf.Variable 實例訪問特定值。
- 直接使用映射網絡時,您可以指定 dlatent_broadcast=None 以禁用合成網絡層上的 dlatents 自動復制。
- 運行時性能可以通過 structure=‘fixed’ 和 dtype=‘float16’ 進行微調。 前者禁用了對完全訓練生成器不需要的漸進式增長的支持,后者使用半精度浮點算法執行所有計算。
準備訓練數據集
訓練和評估腳本對存儲為多分辨率 TFRecord 的數據集進行操作。每個數據集都由一個目錄表示,該目錄包含多種分辨率的相同圖像數據,以實現高效的流式傳輸。每個分辨率都有一個單獨的 *.tfrecords 文件,如果數據集包含標簽,它們也會存儲在單獨的文件中。 默認情況下,腳本希望在 datasets//-.tfrecords 中找到數據集。 可以通過編輯 config.py 來更改目錄:
result_dir = 'results' data_dir = 'datasets' cache_dir = 'cache'要獲取 FFHQ 數據集 (datasets/ffhq),請參閱 Flickr-Faces-HQ 存儲庫。
要獲取 CelebA-HQ 數據集 (datasets/celebahq),請參閱Progressive GAN 存儲庫 。
要獲取其他數據集,包括 LSUN,請查閱其相應的項目頁面。 可以使用提供的 dataset_tool.py 將數據集轉換為多分辨率 TFRecords:
訓練網絡
設置數據集后,您可以按如下方式訓練自己的 StyleGAN 網絡:編輯 train.py 以通過- - 取消注釋或編輯特定行來指定數據集和訓練配置。
- 使用 python train.py 運行訓練腳本。
- 結果將寫入新創建的目錄 results/-。
- 培訓可能需要幾天(或幾周)才能完成,具體取決于配置。
默認情況下,train.py 被配置為使用 8 個 GPU 以 1024×1024 分辨率為 FFHQ 數據集訓練最高質量的 StyleGAN(表 1 中的配置 F)。
使用 Tesla V100 GPU 的默認配置的預期訓練時間:
評估質量和解開
我們論文中使用的質量和解開度量可以使用 run_metrics.py 進行評估。 默認情況下,腳本將評估預訓練的 FFHQ 生成器的 Fréchet 起始距離 (fid50k),并將結果寫入結果下新創建的目錄中。 可以通過取消注釋或編輯 run_metrics.py 中的特定行來更改確切的行為。
使用一個 Tesla V100 GPU 的預訓練 FFHQ 生成器的預期評估時間和結果:
請注意,由于 TensorFlow 的不確定性,確切的結果可能會因運行而異。
參考資料
ustinpinkney/awesome-pretrained-stylegan
BVlabs/stylegan
總結
- 上一篇: MOSSE
- 下一篇: [Leetcode][第322题][JA