风格迁移--U-GAT-IT模型(ICLR 2020)
1 論文簡(jiǎn)介
論文題目: U-gat-it: Unsupervised generative attentional networks with adaptive layer-instance normalization for image-to-image translation
論文代碼:https://github.com/taki0112/UGATIT
論文數(shù)據(jù)集:https://github.com/znxlwm/UGATIT-pytorch
本文以倒序的方式來(lái)介紹這篇論文,首先看效果,然后分析其原理。
2 效果
Figure 2: Visualization of the attention maps and their effects shown in the ablation experiments: (a) Source images, (b) Attention map of the generator, (c-d) Local and global attention maps of the discriminator, respectively. (e) Our results with CAM, (f) Results without CAM.
Figure 3: Comparison of the results using each normalization function: (a) Source images, (b) Our results, ? Results only using IN in decoder with CAM, (d) Results only using LN in decoder with CAM, (e) Results only using AdaIN in decoder with CAM, (f) Results only using GN in decoder with CAM.
3 基本框架
本文提出了一種新的無(wú)監(jiān)督圖像到圖像轉(zhuǎn)換方法,以端到端的方式結(jié)合新的注意力模塊和新的可學(xué)習(xí)歸一化函數(shù)。
- 注意力模塊根據(jù)輔助分類器獲得的注意力圖,引導(dǎo)模型專注于區(qū)分源域和目標(biāo)域的更重要的區(qū)域(幫助模型知道在哪里進(jìn)行密集轉(zhuǎn)換)。 與之前無(wú)法處理域之間幾何變化的基于注意力的方法不同,本文的模型可以轉(zhuǎn)換需要整體變化的圖像和需要大形狀變化的圖像。
- AdaLIN 函數(shù)幫助注意力模型靈活控制形狀和紋理的變化量,而無(wú)需修改模型架構(gòu)或超參數(shù)。
- 實(shí)驗(yàn)結(jié)果表明,與具有固定網(wǎng)絡(luò)架構(gòu)和超參數(shù)的現(xiàn)有最先進(jìn)模型相比,所提出的方法具有優(yōu)越性。
模型分為生成器和判別器兩部分,結(jié)構(gòu)幾乎一致。生成器比判別器多了AdaLIN算法實(shí)現(xiàn)的Decoder模塊。
圖1描述了網(wǎng)絡(luò)結(jié)構(gòu),以生成器為例,輸入圖像通過(guò)Encoder編碼階段(下采樣+殘差模塊)得到特征圖,然后添加一個(gè)輔助分類引入Attention機(jī)制通過(guò)特征圖的最大池化,經(jīng)過(guò)全連接層輸出一個(gè)節(jié)點(diǎn)的預(yù)測(cè),然后將這個(gè)全連接層的參數(shù)和特征圖相乘從而得到Attention的特征圖。最后經(jīng)過(guò)Decoder模塊得到輸出圖像。
Figure 1: The model architecture of U-GAT-IT. The detailed notations are described in Section Model
本文的目標(biāo)是訓(xùn)練一個(gè)函數(shù) Gs→tG_{s \rightarrow t}Gs→t?,該函數(shù)使用從每個(gè)域中抽取未配對(duì)的樣本將圖像從源域XsX_sXs? 映射到目標(biāo)域 XtX_tXt?:
- 該框架由兩個(gè)生成器 Gs→tG_{s \rightarrow t}Gs→t?和 Gt→sG_{t \rightarrow s}Gt→s? 以及兩個(gè)鑒別器 DsD_sDs? 和DtD_tDt? 組成;
- 將注意力模塊集成到生成器和鑒別器中;
- 判別器中的注意力模塊引導(dǎo)生成器關(guān)注對(duì)生成逼真圖像至關(guān)重要的區(qū)域;
- 生成器中的注意力模塊關(guān)注與其他域不同的區(qū)域(判別器注意力模塊已經(jīng)引導(dǎo)生成器聚焦了一個(gè)域,那么生成器的注意力模塊則聚焦其它的域)。
3.1 生成器
在這里,我們只解釋Gs→tG_{s \rightarrow t}Gs→t?和 DtD_tDt?(見(jiàn)圖 1),反之亦然。
符號(hào)說(shuō)明:
x∈{Xs,Xt}x \in\left\{X_{s}, X_{t}\right\}x∈{Xs?,Xt?}:來(lái)自源域和目標(biāo)域的樣本;
Gs→tG_{s \rightarrow t}Gs→t?:包括一個(gè)編碼器EsE_sEs?,一個(gè)解碼器GtG_tGt?,和一個(gè)輔助分類器ηs\eta_sηs?;
ηs(x)\eta_s(x)ηs?(x):表示xxx來(lái)自XsX_sXs?的概率;
Esk(x)E_{s}^{k}(x)Esk?(x):編碼器的第 kkk 個(gè)激活映射(map);
Eskij(x)E_{s}^{k_{i j}}(x)Eskij??(x):在(i,j)(i, j)(i,j)上的值;
wskw_s^kwsk?:通過(guò)使用全局平均池化和全局最大池化訓(xùn)練輔助分類器以學(xué)習(xí)源域的第kkk 個(gè)特征圖的權(quán)重,例如:ηs(x)=σ(ΣkwskΣijEskij(x))\eta_{s}(x)=\sigma\left(\Sigma_{k} w_{s}^{k} \Sigma_{i j} E_{s}^{k_{i j}}(x)\right)ηs?(x)=σ(Σk?wsk?Σij?Eskij??(x));
利用 wskw_s^kwsk?,可以計(jì)算一組特定領(lǐng)域的注意力特征圖:
as(x)=ws?Es(x)={wsk?Esk(x)∣1≤k≤n}a_{s}(x)=w_{s} * E_{s}(x)=\left\{w_{s}^{k} * E_{s}^{k}(x) \mid 1 \leq k \leq n\right\}as?(x)=ws??Es?(x)={wsk??Esk?(x)∣1≤k≤n}。
nnn:編碼特征圖的數(shù)量。
AdaLIN?(a,γ,β)=γ?(ρ?aI^+(1?ρ)?aL^)+β,aI^=a?μIσI2+?,aL^=a?μLσL2+?ρ←clip?[0,1](ρ?τΔρ)(1)\begin{array}{c} \operatorname{AdaLIN}(a, \gamma, \beta)=\gamma \cdot\left(\rho \cdot \hat{a_{I}}+(1-\rho) \cdot \hat{a_{L}}\right)+\beta, \\ \hat{a_{I}}=\frac{a-\mu_{I}}{\sqrt{\sigma_{I}^{2}+\epsilon}}, \hat{a_{L}}=\frac{a-\mu_{L}}{\sqrt{\sigma_{L}^{2}+\epsilon}} \\ \rho \leftarrow \operatorname{clip}_{[0,1]}(\rho-\tau \Delta \rho) \end{array}\tag1 AdaLIN(a,γ,β)=γ?(ρ?aI?^?+(1?ρ)?aL?^?)+β,aI?^?=σI2?+??a?μI??,aL?^?=σL2?+??a?μL??ρ←clip[0,1]?(ρ?τΔρ)?(1)
公式(1)的符號(hào)說(shuō)明:
- γ\gammaγ和β\betaβ由注意力圖的全連接層動(dòng)態(tài)計(jì)算;
- μI\mu_IμI? , μL\mu_LμL? 和σI\sigma_IσI?, σL\sigma_LσL? 分別是通道方式、層方式均值和標(biāo)準(zhǔn)差;
- τ\tauτ為學(xué)習(xí)速率;
- ΔρΔ \rhoΔρ 表示優(yōu)化器確定的參數(shù)更新向量(如梯度);
- ρ\rhoρ的值被限制在[0,1][0,1][0,1]的范圍內(nèi),只需在參數(shù)更新步驟中設(shè)置界限即可;生成器調(diào)整該值,以便在實(shí)例規(guī)范化很重要的任務(wù)中ρ\rhoρ的值接近1,而在層歸一化(LN)很重要的任務(wù)中ρ\rhoρ的值接近0。在解碼器的殘差塊中,ρ\rhoρ的值初始化為1,在解碼器的上采樣塊中,ρ\rhoρ的值初始化為0。
公式(1)中最核心的部分是:
ρ?IN+(1?ρ)?LNaI^=a?μIσI2+?aL^=a?μLσL2+?(2)\begin{array}{c} \rho \cdot IN+(1-\rho) \cdot LN \\ \hat{a_{I}}=\frac{a-\mu_{I}}{\sqrt{\sigma_{I}^{2}+\epsilon}} \\ \hat{a_{L}}=\frac{a-\mu_{L}}{\sqrt{\sigma_{L}^{2}+\epsilon}} \\ \end{array}\tag2 ρ?IN+(1?ρ)?LNaI?^?=σI2?+??a?μI??aL?^?=σL2?+??a?μL???(2)
- 層歸一化(Layer Norm,LN):通道(channel)方向做歸一化,算CHW(通道、高、寬)的均值,主要對(duì)RNN作用明顯;更多的考慮輸入特征通道之間的相關(guān)性,LN比IN風(fēng)格轉(zhuǎn)換更徹底,但是語(yǔ)義信息保存不足;
- 實(shí)例歸一化(Instance Norm,IN):一個(gè)通道(channel)內(nèi)做歸一化,算H*W的均值,用在風(fēng)格化遷移;因?yàn)樵趫D像風(fēng)格化中,生成結(jié)果主要依賴于某個(gè)圖像實(shí)例,所以對(duì)整個(gè)batch歸一化不適合圖像風(fēng)格化中,因而對(duì)HW做歸一化。可以加速模型收斂,并且保持每個(gè)圖像實(shí)例之間的獨(dú)立;更多考慮單個(gè)特征通道的內(nèi)容,IN比LN更好的保存原圖像的語(yǔ)義信息,但是風(fēng)格轉(zhuǎn)換不徹底。
3.2 判別器
3.3 損失函數(shù)
模型包括四個(gè)損失函數(shù):
- 對(duì)抗損失:Llsgans→t=(Ex~Xt[(Dt(x))2]+Ex~Xs[(1?Dt(Gs→t(x)))2])L_{l s g a n}^{s \rightarrow t}=\left(\mathbb{E}_{x \sim X_{t}}\left[\left(D_{t}(x)\right)^{2}\right]+\mathbb{E}_{x \sim X_{s}}\left[\left(1-D_{t}\left(G_{s \rightarrow t}(x)\right)\right)^{2}\right]\right)Llsgans→t?=(Ex~Xt??[(Dt?(x))2]+Ex~Xs??[(1?Dt?(Gs→t?(x)))2]),保證風(fēng)格遷移圖像的分布與目標(biāo)圖像分布相匹配;
- 循環(huán)損失:Lcycle?s→t=Ex~Xs[∥x?Gt→s(Gs→t(x)))∥1]\left.L_{\text {cycle }}^{s \rightarrow t}=\mathrm{E}_{x \sim X_{s}}\left[\| x-G_{t \rightarrow s}\left(G_{s \rightarrow t}(x)\right)\right) \|_{1}\right]Lcycle?s→t?=Ex~Xs??[∥x?Gt→s?(Gs→t?(x)))∥1?],保證一個(gè)圖像x∈Xsx \in X_sx∈Xs?,在從XsX_sXs?到XtX_tXt?,XtX_tXt?到XsX_sXs?一系列轉(zhuǎn)化后,該圖像能成功的轉(zhuǎn)化回原始域;
- 一致性損失:Lidentity?s→t=Ex~Xt[∥x?Gs→t(x)∥1]L_{\text {identity }}^{s \rightarrow t}=\mathrm{E}_{x \sim X t}\left[\left\|x-G_{s \rightarrow t}(x)\right\|_{1}\right]Lidentity?s→t?=Ex~Xt?[∥x?Gs→t?(x)∥1?],保證輸入圖像與輸出圖像的顏色分布相似,給定一個(gè)圖像x∈Xtx \in X_tx∈Xt?,在使用Gs→tG_{s→t}Gs→t?翻譯之后,圖像不應(yīng)該改變;
- 分類激活映射損失:Lcams→t=?(Ex~Xs[log?(ηs(x))]+Ex~Xt[log?(1?ηs(x))]LcamDt=Ex~Xt[(ηDt(x))2]+Ex~Xs[(1?ηDt(Gs→t(x))2]\begin{array}{l} L_{c a m}^{s \rightarrow t}=-\left(\mathrm{E}_{x \sim X_{s}}\left[\log \left(\eta_{s}(x)\right)\right]+\mathrm{E}_{x \sim X_{t}}\left[\log \left(1-\eta_{s}(x)\right)\right]\right. \\ L_{c a m}^{D t}=\mathrm{E}_{x \sim X_{t}}\left[\left(\eta_{D t}(x)\right)^{2}\right]+\mathrm{E}_{x \sim X_{s}}\left[\left(1-\eta_{D t}\left(G_{s \rightarrow t}(x)\right)^{2}\right]\right. \end{array}Lcams→t?=?(Ex~Xs??[log(ηs?(x))]+Ex~Xt??[log(1?ηs?(x))]LcamDt?=Ex~Xt??[(ηDt?(x))2]+Ex~Xs??[(1?ηDt?(Gs→t?(x))2]?,輔助分類器ηsη_sηs?和ηDtη_{D_t}ηDt??帶來(lái)的損失。
最后,聯(lián)合訓(xùn)練編碼器、解碼器、判別器和輔助分類器以優(yōu)化最終目標(biāo)函數(shù):
min?Gs→t,Gt→s,ηs,ηtmax?Ds,Dt,ηDs,ηDtλ1Llsgan?+λ2Lcycle?+λ3Lidentity?+λ4Lcam?\min _{G_{s \rightarrow t}, G_{t \rightarrow s}, \eta_{s}, \eta_{t}} \max _{D_{s}, D_{t}, \eta_{D_{s}}, \eta_{D_{t}}} \lambda_{1} L_{\text {lsgan }}+\lambda_{2} L_{\text {cycle }}+\lambda_{3} L_{\text {identity }}+\lambda_{4} L_{\text {cam }} Gs→t?,Gt→s?,ηs?,ηt?min?Ds?,Dt?,ηDs??,ηDt??max?λ1?Llsgan??+λ2?Lcycle??+λ3?Lidentity??+λ4?Lcam??
其中λ1=1,λ2=10,λ3=10,λ4=1000\lambda_{1}=1, \lambda_{2}=10, \lambda_{3}=10, \lambda_{4}=1000λ1?=1,λ2?=10,λ3?=10,λ4?=1000, Llsgan?=Llsgan?s→t+Llsgan?t→s,Lcycle?=Lcycle?s→t+Lcycle?t→s,Lidentity?=Ldentity?s→t+Lidentity?t→s,Lcam?=Lcam?s→t+Lcam?t→sL_{\text {lsgan }}=L_{\text {lsgan }}^{s \rightarrow t}+L_{\text {lsgan }}^{t \rightarrow s}, L_{\text {cycle }}=L_{\text {cycle }}^{s \rightarrow t}+L_{\text {cycle }}^{t \rightarrow s}, L_{\text {identity }}=L_{\text {dentity }}^{s \rightarrow t}+L_{\text {identity }}^{t \rightarrow s}, L_{\text {cam }}=L_{\text {cam }}^{s \rightarrow t}+L_{\text {cam }}^{t \rightarrow s}Llsgan??=Llsgan?s→t?+Llsgan?t→s?,Lcycle??=Lcycle?s→t?+Lcycle?t→s?,Lidentity??=Ldentity?s→t?+Lidentity?t→s?,Lcam??=Lcam?s→t?+Lcam?t→s?。
總結(jié)
以上是生活随笔為你收集整理的风格迁移--U-GAT-IT模型(ICLR 2020)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: vb.net调用oracle存储过程,v
- 下一篇: From AlphaGo Zero to