巧断梯度:单个loss实现GAN模型(附开源代码)
作者丨蘇劍林
單位丨廣州火焰信息科技有限公司
研究方向丨NLP,神經(jīng)網(wǎng)絡(luò)
個(gè)人主頁(yè)丨kexue.fm
我們知道普通的模型都是搭好架構(gòu),然后定義好 loss,直接扔給優(yōu)化器訓(xùn)練就行了。但是 GAN 不一樣,一般來(lái)說(shuō)它涉及有兩個(gè)不同的 loss,這兩個(gè) loss 需要交替優(yōu)化。
現(xiàn)在主流的方案是判別器和生成器都按照 1:1 的次數(shù)交替訓(xùn)練(各訓(xùn)練一次,必要時(shí)可以給兩者設(shè)置不同的學(xué)習(xí)率,即 TTUR),交替優(yōu)化就意味我們需要傳入兩次數(shù)據(jù)(從內(nèi)存?zhèn)鞯斤@存)、執(zhí)行兩次前向傳播和反向傳播。
如果我們能把這兩步合并起來(lái),作為一步去優(yōu)化,那么肯定能節(jié)省時(shí)間的,這也就是 GAN 的同步訓(xùn)練。
注:本文不是介紹新的 GAN,而是介紹 GAN 的新寫法,這只是一道編程題,不是一道算法題。
如果在TF中
如果是在 TensorFlow 中,實(shí)現(xiàn)同步訓(xùn)練并不困難,因?yàn)槲覀兌x好了判別器和生成器的訓(xùn)練算子了(假設(shè)為 D_solver 和 G_solver ),那么直接執(zhí)行:
就行了。這建立在我們能分別獲取判別器和生成器的參數(shù)、能直接操作 sess.run 的基礎(chǔ)上。
更通用的方法?
但是如果是 Keras 呢?Keras 中已經(jīng)把流程封裝好了,一般來(lái)說(shuō)我們沒(méi)法去操作得如此精細(xì)。
所以,下面我們介紹一個(gè)通用的技巧,只需要定義單一一個(gè) loss,然后扔給優(yōu)化器,就能夠?qū)崿F(xiàn) GAN 的訓(xùn)練。同時(shí),從這個(gè)技巧中,我們還可以學(xué)習(xí)到如何更加靈活地操作 loss 來(lái)控制梯度。
判別器的優(yōu)化
我們以 GAN 的 hinge loss 為例子,它的形式是:
注意意味著要固定 G,因?yàn)?G 本身也是有優(yōu)化參數(shù)的,不固定的話就應(yīng)該是。
為了固定G,除了“把 G 的參數(shù)從優(yōu)化器中去掉”這個(gè)方法之外,我們也可以利用 stop_gradient 去手動(dòng)固定:
這里:
這樣一來(lái),在式 (2) 中,我們雖然同時(shí)放開了 D,G 的權(quán)重,但是不斷地優(yōu)化式 (2),會(huì)變的只有 D,而 G 是不會(huì)變的,因?yàn)槲覀冇玫氖腔谔荻认陆档膬?yōu)化器,而 G 的梯度已經(jīng)被停止了,換句話說(shuō),我們可以理解為 G 的梯度被強(qiáng)行設(shè)置為 0,所以它的更新量一直都是 0。?
生成器的優(yōu)化
現(xiàn)在解決了 D 的優(yōu)化,那么 G 呢? stop_gradient 可以很方便地放我們固定里邊部分的梯度(比如 D(G(z)) 的 G(z)),但 G 的優(yōu)化是要我們?nèi)ス潭ㄍ膺叺?D,沒(méi)有函數(shù)實(shí)現(xiàn)它。但不要灰心,我們可以用一個(gè)數(shù)學(xué)技巧進(jìn)行轉(zhuǎn)化。?
首先,我們要清楚,我們想要 D(G(z)) 里邊的 G 的梯度,不想要 D 的梯度,如果直接對(duì) D(G(z)) 求梯度,那么同時(shí)會(huì)得到 D,G 的梯度。如果直接求的梯度呢?只能得到 D 的梯度,因?yàn)?G 已經(jīng)被停止了。那么,重點(diǎn)來(lái)了,將這兩個(gè)相減,不就得到單純的 G 的梯度了嗎!
現(xiàn)在優(yōu)化式 (4) ,那么 D 是不會(huì)變的,改變的是 G。?
值得一提的是,直接輸出這個(gè)式子,結(jié)果是恒等于 0,因?yàn)閮刹糠侄际且粯拥?#xff0c;直接相減自然是 0,但它的梯度不是 0。也就是說(shuō),這是一個(gè)恒等于 0 的 loss,但是梯度卻不恒等于 0。?
合成單一loss?
好了,現(xiàn)在式 (2) 和式 (4) 都同時(shí)放開了 D,G,大家都是 arg min,所以可以將兩步合成一個(gè) loss:
寫出這個(gè) loss,就可以同時(shí)完成判別器和生成器的優(yōu)化了,而不需要交替訓(xùn)練,但是效果基本上等效于 1:1 的交替訓(xùn)練。引入 λ 的作用,相當(dāng)于讓判別器和生成器的學(xué)習(xí)率之比為 1:λ。
參考代碼:
https://github.com/bojone/gan/blob/master/gan_one_step_with_hinge_loss.py
文章小結(jié)
文章主要介紹了實(shí)現(xiàn) GAN 的一個(gè)小技巧,允許我們只寫單個(gè)模型、用單個(gè) loss 就實(shí)現(xiàn) GAN 的訓(xùn)練。它本質(zhì)上就是用 stop_gradient 來(lái)手動(dòng)控制梯度的技巧,在其他任務(wù)上也可能用得到它。
所以,以后我寫 GAN 都用這種寫法了,省力省時(shí)。當(dāng)然,理論上這種寫法需要多耗些顯存,這也算是犧牲空間換時(shí)間吧。
點(diǎn)擊以下標(biāo)題查看作者其他文章:?
變分自編碼器VAE:原來(lái)是這么一回事 | 附開源代碼
再談變分自編碼器VAE:從貝葉斯觀點(diǎn)出發(fā)
變分自編碼器VAE:這樣做為什么能成?
從變分編碼、信息瓶頸到正態(tài)分布:論遺忘的重要性
深度學(xué)習(xí)中的互信息:無(wú)監(jiān)督提取特征
全新視角:用變分推斷統(tǒng)一理解生成模型
細(xì)水長(zhǎng)flow之NICE:流模型的基本概念與實(shí)現(xiàn)
細(xì)水長(zhǎng)flow之f-VAEs:Glow與VAEs的聯(lián)姻
深度學(xué)習(xí)中的Lipschitz約束:泛化與生成模型
#投 稿 通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢??答案就是:你不認(rèn)識(shí)的人。
總有一些你不認(rèn)識(shí)的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵(lì)高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺(tái)上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)習(xí)心得或技術(shù)干貨。我們的目的只有一個(gè),讓知識(shí)真正流動(dòng)起來(lái)。
??來(lái)稿標(biāo)準(zhǔn):
? 稿件確系個(gè)人原創(chuàng)作品,來(lái)稿需注明作者個(gè)人信息(姓名+學(xué)校/工作單位+學(xué)歷/職位+研究方向)?
? 如果文章并非首發(fā),請(qǐng)?jiān)谕陡鍟r(shí)提醒并附上所有已發(fā)布鏈接?
? PaperWeekly 默認(rèn)每篇文章都是首發(fā),均會(huì)添加“原創(chuàng)”標(biāo)志
? 投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請(qǐng)單獨(dú)在附件中發(fā)送?
? 請(qǐng)留下即時(shí)聯(lián)系方式(微信或手機(jī)),以便我們?cè)诰庉嫲l(fā)布時(shí)和作者溝通
?
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁(yè)搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
關(guān)于PaperWeekly
PaperWeekly 是一個(gè)推薦、解讀、討論、報(bào)道人工智能前沿論文成果的學(xué)術(shù)平臺(tái)。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號(hào)后臺(tái)點(diǎn)擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
▽ 點(diǎn)擊 |?閱讀原文?| 查看作者博客
總結(jié)
以上是生活随笔為你收集整理的巧断梯度:单个loss实现GAN模型(附开源代码)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: AAAI 2019 论文解读 | 基于区
- 下一篇: Self-Attention GAN 中