神经网络优化中的Weight Averaging
?PaperWeekly 原創 ·?作者|張子遜
研究方向|神經網絡剪枝、NAS
在神經網絡優化的研究中,有研究改進優化器本身的(例如學習率衰減策略、一系列 Adam 改進等等),也有不少是改進 normalization 之類的技術(例如 Weight Decay、BN、GN 等等)來提高優化器的性能和穩定性。除此之外,還有一個比較常見的技術就是 Weight Averaging,也就是字面意思對網絡的權重進行平均,這也是一個不錯的提高優化器性能/穩定性的方式。
Stochastic Weight Averaging (SWA)
在神經網絡的優化中,有一個比較公認的問題就是 train loss 和 test loss 優化曲面不一致。導致這個問題可能的原因有很多,(以下是個人推測)可能是 train 和 test 數據本身存在分布不一致的問題、也可能是因為訓練的時候在 train loss 中加入了一系列正則化等等。由于這個不一致的問題,就會導致優化出來的網絡 generalization performance 可能會不好。?
本文提出了一個比較簡單直接的方式來解決這個問題,在優化的末期取 k 個優化軌跡上的 checkpoints,平均他們的權重,得到最終的網絡權重,這樣就會使得最終的權重位于 flat 曲面更中心的位置。這個方法也被集成到了 PyTorch 1.6 之后的版本中。
本文的實驗分析部分也給出了詳細的分析和結果來驗證這樣的一種方法是有效的。首先是分析的部分,本文通過可視化繪制了通過 SWA 和 SGD 在 loss landscape 上的收斂位置,以及 SGD 優化得到的權重對應的 loss 和 error 相比 SWA 得到的權重的距離,如下圖所示。
從圖上可以看到幾個有趣的現象:首先,train loss 和 test loss 的 landscape 之間確實存在偏移;其次,SGD 更傾向于收斂到 flat 區域的邊緣。比較直觀的一種猜想就是,利用這樣的性質,SWA 可以通過平均 flat 區域邊緣的一些 checkpoints,來使得最終的收斂位置更靠近中心的位置。在實際實驗中,也顯示了類似的結果,經過 SWA 平均之后,網絡的測試準確率都有不同程度的提升。
Stochastic Weight Averaging in Parallel
這篇文章是對上面的 SWA 在并行優化中的一個應用。在并行優化神經網絡的過程中,batch size 的增加可以使 SGD 的梯度計算更精確,因此可以使用更大的 lr 進行優化,同時也可以縮短優化的次數。但是這樣往往也會使得優化出來的 generalization performance 更差,因此就需要引入一些額外的技術來避免這個問題。?
本文則是講優化過程分成三個階段。在前期利用 large batch size 的優勢讓網絡的 loss 快速收斂到一個相對不錯的平坦區。在第二階段每一個節點獨立的用 mini batch 來更新模型。最后利用 SWA 來對這些模型進行平均,改善 large batch 帶來的 generalization 問題。
本文在實驗分析中,同樣發現了類似 SWA 的現象就是 train loss 的曲面與 test loss 的曲面不一致。本文利用可視化方法繪制了一張 CIFAR-10 上的 loss landscape 如下圖所示。
從圖上可見,train loss 的 flat 區域要比 test loss 大得多;同時 SGD 更傾向于停在 flat 區域的邊緣。而經過 SWA 之后,平均之后的模型有更大的概率落在相對中心的位置上。
Lookahead Optimizer
前面提到的 SWA 在優化上,并沒有改變原本優化器更新的梯度,只在結束之后選取一部分 checkpoints 進行 weight averaging 得到最終的權重。而這篇文章則是在更新過程中,利用指數移動平均的方式來計算梯度更新權重。?
本文提出了一種權重的更新策略,每一個 step 的優化中維護兩組權重,第一組稱為 fast weight,就是常規優化器更新得到的權重,第二組稱為 slow weight,是利用 fast weight 得到的權重。之所以稱為 fast/slow,是由于二者的更新頻率不同,先用 fast 更新 k 步,然后根據得到的 fast weight 更新 slow weight 1 步作為這一個優化 step 的結果,依次循環進行。
在實驗中,本文也通過大量的實驗,驗證了 Lookahead 優化器在前期比 SGD 優化的更快。此外,在實驗分析中,本文也發現了一個有趣的現象,就是每一輪從 slow weight 開始 fast weight(SGD/Adam 更新)反而讓 loss 上升了,而經過 slow weight 移動平均之后 loss 又恢復了下降。
Filter Grafting
這篇論文雖然在出發點上與前文并不相同,但是實際的方法也可以看作是一種 weight average,因此也總結在這里了。?
神經網絡經常會出現冗余的問題,一般的方法都是剪枝的方式來消除冗余部分,而也有一些其他的方法則是重新利用冗余部分來提高網絡的性能/泛化性,例如 Dense sparse dense training、RePr 等等。本文也是同樣的出發點,希望通過引入外部信息來改善冗余問題。?
本文提出了一種利用權重的熵來評估網絡中 filter 所包含的信息量,在優化中,同時優化兩個相同的網絡,采用不同的超參來進行優化,在優化過程中對 filter 進行加權平均實現對信息量的補足。
在加權平均時,為了簡潔,不采用針對特定位置的 filter 進行加權,而是根據一層的熵大小來對整個一層的參數進行平均,加權所用的 alpha 則是根據兩個網絡這一層熵的大小自適應決定的。更進一步,則是可以將 2 個網絡相互加權,拓展成多個網絡循環加權,如下圖所示。
在實驗中,本文除了對這種 grafting 策略進行了性能測試,也對其他的一些細節進行了分析:(1) 不同的信息來源對提升的影響(網絡自身、噪聲、不同網絡);(2) 不同的信息量評估方式的影響(L1 norm、熵)。最終得出文中提出的多個網絡基于熵的雜交策略是最優的。同時也對雜交訓練得到的網絡的冗余量(權重 L1 norm 統計)和網絡最終熵之和進行了分析。
參考文獻
[1]?Izmailov P, Podoprikhin D, Garipov T, et al. Averaging weights leads to wider optima and better generalization[J]. arXiv preprint arXiv:1803.05407, 2018.?
[2] Gupta V, Serrano S A, DeCoste D. Stochastic Weight Averaging in Parallel: Large-Batch Training that Generalizes Well[J]. arXiv preprint arXiv:2001.02312, 2020.?
[3] Zhang M, Lucas J, Ba J, et al. Lookahead optimizer: k steps forward, 1 step back[J]. Advances in Neural Information Processing Systems, 2019, 32: 9597-9608.?
[4] Meng F, Cheng H, Li K, et al. Filter grafting for deep neural networks[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020: 6599-6607.
更多閱讀
#投 稿?通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學習心得或技術干貨。我們的目的只有一個,讓知識真正流動起來。
?????來稿標準:
? 稿件確系個人原創作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發,請在投稿時提醒并附上所有已發布鏈接?
? PaperWeekly 默認每篇文章都是首發,均會添加“原創”標志
?????投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發送?
? 請留下即時聯系方式(微信或手機),以便我們在編輯發布時和作者溝通
????
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結
以上是生活随笔為你收集整理的神经网络优化中的Weight Averaging的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 强化学习中的调参经验与编程技巧(on p
- 下一篇: 车辆损失险的赔偿范围