【StyleGAN代码学习】StyleGAN模型架构
完整StyleGAN筆記:http://www.gwylab.com/pdf/Note_StyleGAN.pdf
基于StyleGAN的一個好玩的網站:www.seeprettyface.com
—————————————————————————————————
??
第二章 StyleGAN代碼解讀(上)
??這一章將對StyleGAN的代碼進行非常細致的分析和解讀。一方面有助于對StyleGAN的架構和原理有更深的認識,另一方面是覺得AdaIN的思想很有價值,希望把它寫代碼的技巧學習下來,以后應該在GANs中會有很多能用得上的地方(其它paper里挺多出現了AdaIN的地方)。含有中文注釋的代碼可以在這里獲得。
2.1 StyleGAN代碼架構總覽
??????圖2.1 StyleGAN官方代碼架構
??如圖2.1所示,StyleGAN代碼的封裝與解耦做的非常細致,可見作者的代碼功底是非常扎實的。簡單來說,在dnnlib文件夾下封裝了日志提交工具、tensorflow環境配置與網絡處理工具以及一些雜項實用類函數,這個文件夾盡量不要去動;在metrics文件夾下定義了許多指標計算方法,包括FID、LS、PPL指標以及一些GANs的通用指標計算方法;而training文件夾是需重點關注的部分,里面包含了數據處理、模型架構、損失函數和訓練過程等基于StyleGAN的核心內容,在接下來的筆記中也會重點對這一部分進行細致講解;最后,在主目錄下,有一些全局配置、功能展示和運行接口的代碼,其中train.py值得細讀一下,它是訓練StyleGAN網絡的主要切入點。
??在接下來的筆記中,將從三個部分解讀StyleGAN的代碼,分別是:模型架構、損失函數和訓練過程。至于其它部分的代碼,由于我并不是特別關注,就不再贅述了。
2.2 網絡架構代碼解讀
??StyleGAN的網絡架構全都寫在training/networks_stylegan.py下,主要包括四個組成部分(代碼302行-659行):G_style(),G_mapping(),G_synthesis()和D_basic()。
??如上圖所示,G_style表示整個生成器的網絡架構,它由兩個子網絡組成,分別是映射網絡G_mapping和合成網絡G_synthesis;然后D_basic表示整個判別器的網絡架構,它沿用了ProGAN中的模型設計。
2.2.1 G_style網絡
??G_style網絡位于代碼302-379行。在G_style中定義的組件包括:參數驗證->設置子網絡->設置變量->計算映射網絡輸出->更新W的移動平均值->執行樣式混合正則化->應用截斷技巧->計算合成網絡輸出。其中設置子網絡就是調用構建G_mapping和G_synthesis的過程,兩個子網絡的定義將在下兩節介紹。
??· G_style輸入參數(line303-315)
??輸入參數包括512維的Z碼向量和條件標簽(在FFHQ數據集上沒有使用標簽),和一些可選的參數,包括截斷參數、計算移動平均值W時的參數、樣式混合參數、網絡狀態參數和子網絡的參數們等。
??· 參數驗證(line318-330)
??對輸入參數進行驗證,主要是對網絡狀態和其對截斷率、W平均值衰減率和樣式混合概率的關系之間進行驗證。
??· 設置子網絡(line333-338)
??直接使用tflib.Network()類(充當參數化網絡構建功能的便捷包裝,提供多種實用方法并方便地訪問輸入/輸出/權重)創建兩個子網絡,子網絡的內容在后面的函數(func_name = G_synthesis或func_name = G_mapping)中被定義。
??· 設置變量(line341-342)
??設置兩個變量lod_in和dlatent_avg。前者決定當前處在第幾層分辨率,即lod=resolution_log2–res(其中res表示當前層對應的分辨率級別(2-10));后者決定截斷操作的基準,即生成人臉的dlatent碼的平均值。
??· 計算映射網絡輸出(line345)
??得到映射網絡的輸出,即中間向量W’。
??· 更新W的移動平均值(line348-353)
??把batch的dlatent平均值朝著總dlatent平均值以dlatent_avg_beta步幅靠近,作為新的人臉dlatent平均值,即dlatent_avg。
??· 執行樣式混合正則化(line356-366)
??樣式混合正則化其實很好理解,就是隨機創建一個新的潛碼,這個潛碼以一定概率與原始潛碼交換某一部分,對于交換后的混合潛碼,其生成的圖片也要能夠逼真,這就是樣式混合正則化的實現。
??· 應用截斷技巧(line369-374)
??截斷是指,用平均臉dlatent_avg朝著當前臉dlatents以coefs步幅靠近,得到的結果就是截斷的dlatents。
??· 計算合成網絡輸出(line377-379)
??將截斷的dlatents傳給G_synthesis網絡進行合成,得到的結果就是整個生成網絡G_style的輸出結果。
?
2.2.2 G_mapping網絡
??G_mapping網絡位于代碼384-435行。如上圖所示,G_mapping網絡實現了從初始生成碼到中間向量的映射過程。在G_mapping中定義的組件包括:輸入->連接標簽->歸一化潛碼->映射層->廣播->輸出。其中映射層由8個全連接層組成。
?
??· G_mapping輸入參數(line385-398)
??輸入參數包括512維的Z碼向量和條件標簽(在FFHQ數據集上沒有使用標簽),和一些可選的參數,包括初始向量Z參數、中間向量W參數、映射層設置、激活函數設置、學習率設置以及歸一化設置等。
??· 網絡輸入(line403-407)
??處理好latent的大小和格式后,其值賦給x,即用x標識網絡的輸入。
??· 連接標簽(line410-414)
??原始StyleGAN是無標簽訓練集,這部分不會被調用。
??· 歸一化潛碼(line417-418)
??pixel_norm()(line239-242)
??逐像素歸一化的實現方式為:,其中?=10^-8。
??為何要使用pixel_norm呢? Pixel norm,它是local response normalization的變種,具有避免生成器梯度爆炸的作用。Pixel norm沿著channel維度做歸一化(axis=1),這樣歸一化的一個好處在于,feature map的每個位置都具有單位長度。這個歸一化策略與作者設計的Generator輸出有較大關系,注意到Generator的輸出層并沒有Tanh或者Sigmoid激活函數。
??· 映射層(line421-426)
??構建了mapping_layers層映射層,每個映射層有三個組成部分:全連接層dense()、偏置函數apply_bias()和激活函數act()。
??1)全連接層dense()(line 154-159)
??dense()函數中首先將輸出全部展平以計算輸出維度,然后調用get_weight()創建全連接層w,最后返回x與w的矩陣相乘的結果,作為dense()層的輸出。
??get_weight()(line 135-149)
??get_weight()函數是用來創建卷積層或完全連接層,且獲取權重張量的函數。
??值得注意的是,get_weight()采用了He的初始化方法。He的初始化方法能夠確保網絡初始化的時候,隨機初始化的參數不會大幅度地改變輸入信號的強度。StyleGAN中不僅限初始狀態scale而且也是實時scale。
??2)添加偏置apply_bias()(line 213-218)
??對給定的激活張量施加偏差。
??3)激活函數act()(line 400)
??激活函數采用mapping_nonlinearity的值,StyleGAN中選用’lrelu’,且增益值為√2。
??注意這兒的gain是一個增益值,增益值是指的非線性函數穩態時輸入幅度與輸出幅度的比值,通常被用來乘在激活函數之后使激活函數更加穩定。常見的增益值包括:Sigmoid推薦的增益值是1,Relu推薦的增益值是√2,LeakyRelu推薦的增益值是。
??· 廣播(line429-431)
??這兒只定義了簡單的復制擴充,廣播前x的維度是(?,512),廣播后x的維度是(?,18,512)。
??· 輸出(line 434-435)
??廣播后的中間向量,就是G_mapping網絡的最終輸出。
2.2.3 G_synthesis網絡
??G_synthesis網絡位于代碼441-560行。如上圖所示,G_synthesis網絡實現了從廣播得到的中間向量到生成圖片的合成過程。在G_synthesis中定義的組件包括:預處理->主要輸入->噪音輸入->★每層層末的調制功能->早期層(4*4)結構->剩余層的block塊->★網絡增長變換過程->輸出。其中,每層層末的調制功能是指的在卷積之后,融入噪音與樣式控制的過程(上圖的⊕與AdaIN過程);網絡增長變換過程是指的在訓練時期,合成網絡的架構動態增長,以生成更高分辨率圖片的過程。上述兩個內容是重點值得學習的部分。
?
??· G_synthesis輸入參數(line441-462)
??輸入參數包括512維的中間向量(W)和輸出圖片的分辨率及通道,和一些可選的參數,包括各層特征圖的設置、樣式/網絡起始值/噪音設置、激活函數設置、數據處理設置以及網絡增長架構的設置等。
??· 預處理(line464-474)
??預處理部分除了進一步細化網絡配置以外,還定義了兩個函數——nf()返回在第stage層中特征圖的數量;blur()對圖片進行濾波模糊操作,有利于降噪,其中blur的函數實現方式為blur2d()。
??blur2d ()(line 96-106)
??在blur2d()里定義了模糊的返回函數為_blur2d(),同時blur2d()的一階導和二階導也被定義了出來,都是直接使用_blur2d()函數作為近似。
??_blur2d ()(line 22-49)
??_blur2d ()的實現主要有兩個部分,第一個部分是對于卷積核的處理,包括維度的規范和歸一化處理;第二個部分是模糊的實現,即用卷積核對x實行depthwise_conv2d卷積。
??注意:depthwise_conv2d與普通的卷積有些不同。普通的卷積對卷積核每一個out_channel的兩個通道分別和輸入的兩個通道做卷積相加,得到feature map的一個channel,而depthwise_conv2d卷積對每一個對應的in_channel,分別卷積生成兩個out_channel,所以獲得的feature map的通道數量可以用in_channel * channel_multiplier來表達,這個channel_multiplier,就可以理解為卷積核的第四維。參見博客:https://blog.csdn.net/mao_xiao_feng/article/details/78003476。
??· 主要輸入(line 477-479)
??主要輸入除了dlatents_in之外,還有一個lod_in參數。lod_in是一個指定當前輸入分辨率級別的參數,規定lod = resolution_log2 – res。lod_in在遞歸構建模型架構的部分中被使用。
??· 創建噪音(line 482-487)
??最初創建噪音時,只是依據對應層的分辨率創建對應的shape,然后隨機初始化即為一個噪音。
??· ★層末調制(含AdaIN,line490-501)
??層末調制,是在每個block的卷積之后對特征的處理,包含6種(可選)內容:應用噪音apply_noise()、應用偏置apply_bias()、應用激活函數act()、逐像素歸一化pixel_norm()、實例歸一化instance_norm()和樣式調制(AdaIN)style_mod()。其中apply_bias()、act()與pixel_norm()在前文中已提及過,下面將不再贅述。
??1)apply_noise()(line 270-278)
??應用噪音,直接將噪音加在特征x上就行了,注意按channel疊加。
??2)instance_norm()(line 247-256)
??實例歸一化是一個在生成模型中應用非常廣泛的歸一化方式,它的主要特點是僅對特征的HW(高和寬)維度做歸一化,對圖像的風格影響明顯。
??3)★style_mod ()(line 261-265)
??樣式控制(AdaIN)的代碼只有3行。第1行是仿射變化A,它通過全連接層將dlatent擴大一倍;第2行將擴大后的dlatent轉換為放縮因子y_(s,i)和偏置因子y_(b,i);第3行是將這兩個因子對卷積后的x實施自適應實例歸一化。
??· 早期層結構(line 504-514)
??由于StyleGAN的網絡結構隨訓練進行是動態變化的,所以代碼中定義了訓練最開始時的網絡結構,即4*4分辨率的生成網絡。StyleGAN的生成起點選用維度為(1,512,4,4)的常量,通過一個卷積層(conv2d)得到了通道數為nf(1)(即512維)的特征圖。
??· conv2d(line 164-168)
??conv2d通過簡單的卷積實現,將x的通道數由x.shape[1]變為fmaps,而x的大小不變。
??· 剩余層的block塊(line 517-527)
??StyleGAN將剩余分辨率的網絡層封裝成了block函數,方便在訓練過程中依據輸入的res(由lod_in計算出來)實時構建及調整網絡的架構。其中每個block都包括了一個上采樣層(upscaled_conv2d)和一個卷積層,上采樣層后置濾波處理與層末調制,卷積層后置層末調制。另外在訓練過程中網絡需實時輸出圖片,StyleGAN中定義了torgb()函數,負責將對應分辨率的特征圖轉換為RGB圖像。
??· upscaled_conv2d()(line 174-191)
??upscaled_conv2d利用tf.nn.conv2d_transpose反卷積操作實現將特征圖放大一倍。其中一個值得注意的操作是,其卷積核被輕微平移了四次并對自身做了疊加,這樣或許對于提取特征有幫助,但我查閱不到相關資料證明這一點。
??· ★網絡增長變換過程(line 530-556)
??StyleGAN的生成網絡需具備動態變換的能力,代碼中定義了三種結構組合方式,分別是:固定結構、線性結構與遞歸結構。
??1)固定結構(line 530-533)
??固定結構構建了直達1024*1024分辨率的生成器網絡,簡單高效但不支持漸進式增長。
??2)線性結構(line 536-544)
??線性結構構建了upscale2d()的上采樣層,能將當前分辨率放大一倍。另外在不同分辨率間變換時,線性結構采用了含大小值裁剪的線性插值,實現了不同分辨率下的平滑過渡。
??upscale2d()(line 108-118)
??在upscale2d()里定義了上采樣的返回函數為_upscale2d(),同時upscale2d()的一階導和二階導也被定義了出來,都是直接使用_upscale2d()函數作為近似。
??_upscale2d()(line 51-68)
??_upscale2d()的實現使用了tf.tile()的騷操作,其實很簡單,就是復制擴展了像素值。總而言之,線性結構的實現簡單但效率低下。
??3)遞歸結構(line 547-556)
??遞歸結構下定義了遞歸函數grow(),使得只需要調用一次grow()就能夠實現所有分辨率層的構建。它的實現邏輯主要是:比較lod_in和lod的關系——當lod_in超過lod時,在同一層級上實現分辨率的線性擴增;當lod_in小于lod且不是最后一層時,跳轉到下一級的分辨率層上。
??· 輸出(line 558-559)
??網絡的最終輸出為之前結構中得到的images_out。
??· G_style架構總覽
??最后,通過一張G_style的完整網絡架構圖,讓我們對各個層的名稱、參數量、輸入維度和輸出維度有更具體的理解。
2.2.4 D_basic網絡
??D_basic網絡位于代碼566-661行。如上圖所示,D_basic網絡實現區分合成圖片與真實圖片的功能,沿用了ProGAN判別器的架構,基本上是生成器反過來的樣子。在D_basic中定義的組件包括:預處理->構建block塊->★網絡增長變換過程->標簽計算->輸出。其中,網絡增長變換過程是指的在訓練時期,隨著生成圖片的分辨率提升,判別網絡的架構動態增長的過程。另外,標簽計算是指在訓練集使用了含標簽數據時,會將標簽值與判別分數的乘積作為最終判別網絡的輸出值。
??· D_basic輸入參數(line565-582)
??輸入參數包括圖片、標簽以及這二者的相關配置,和一些可選的參數,包括各層特征圖的設置、激活函數設置、小批量標準偏差層設置、數據處理設置以及網絡增長架構的設置等。
??· 預處理(line584-596)
??預處理主要包括細化網絡配置、輸入數據的處理以及輸出數據的定義,包含內容與G_synthesis中的預處理過程類似。
??· 構建block塊(line 599-618)
??在訓練過程中網絡需實時處理圖片,StyleGAN中定義了fromrgb()函數,負責將對應分辨率的RGB圖像轉換為特征圖。
??在block函數中,當分辨率不低于88時,一個block包含一個卷積層和一個下采樣層(conv2d_downscale2d);而當分辨率為最開始的44時,一個block包含一個小批量標準偏差層(minibatch_stddev_layer)、一個卷積層和兩個全連接層。
??conv2d_downscale2d()(line 193-208)
??conv2d_downscale2d利用tf.nn.conv2d卷積操作實現將特征圖縮小一倍。其中一個值得注意的操作是,其卷積核被輕微平移了四次并對自身做了疊加然后取平均值,這樣或許對于提取特征有幫助,但我查閱不到相關資料證明這一點。
??minibatch_stddev_layer()(line 283-296)
??在4*4分辨率的block中,構建了一個小批量標準偏差層,將特征圖標準化處理,這樣能讓判別網絡收斂得更快。
??· ★網絡增長變換過程(line 621-650)
??StyleGAN的判別網絡需具備動態變換的能力,代碼中定義了三種結構組合方式,分別是:固定結構、線性結構與遞歸結構。
??1)固定結構(line 621-625)
??固定結構構建了直達10241024分辨率的判別器網絡,簡單高效但不支持漸進式增長。
??2)線性結構(line 628-638)
??線性結構構建了downscale2d()的下采樣層,能將當前分辨率縮小一倍。另外在不同分辨率間變換時,線性結構采用了含大小值裁剪的線性插值,實現了不同分辨率下的平滑過渡。
??downscale2d()(line 120-130)
??在downscale2d()里定義了下采樣的返回函數為_downscale2d(),同時downscale2d()的一階導和二階導也被定義了出來,都是直接使用_downscale2d ()函數作為近似。
??_downscale2d()(line 70-90)
??_downscale2d()中,如果卷積核大小為22,則直接返回_blur2d()的結果;否則采用平均池化的方式實現下采樣。
??3)遞歸結構(line 641-650)
??遞歸結構下定義了遞歸函數grow(),使得只需要調用一次grow()就能夠實現所有分辨率層的構建。它的實現邏輯比較復雜,請參見代碼注釋;值得注意的是,構建判別網絡時是lod從大往小構建,所以遞歸的過程是與生成器相反的。
??· 標簽計算(line 653-655)
??如果使用了標簽的話,將標簽值與判別分數的乘積作為最終判別網絡的輸出值。
??· 輸出(line 657-659)
??網絡的最終輸出為scores_out。
??· D_basic架構總覽
??最后,通過一張D_basic的完整網絡架構圖,讓我們對各個層的名稱、參數量、輸入維度和輸出維度有更具體的理解。
總結
以上是生活随笔為你收集整理的【StyleGAN代码学习】StyleGAN模型架构的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 工业数据采集平台
- 下一篇: 【单片机】继电器控制