PyTorch基础(12)-- torch.nn.BatchNorm2d()方法
Batch Normanlization簡稱BN,也就是數據歸一化,對深度學習模型性能的提升有很大的幫助。BN的原理可以查閱我之前的一篇博客。白話詳細解讀(七)----- Batch Normalization。但為了該篇博客的完整性,在這里簡單介紹一下BN。
一、BN的原理
BN的基本思想其實相當直觀:因為深層神經網絡在做非線性變換前的激活輸入值(就是那個x=WU+B,U是輸入)隨著網絡深度加深或者在訓練過程中,其分布逐漸發生偏移或者變動,之所以訓練收斂慢,一般是整體分布逐漸往非線性函數的取值區間的上下限兩端靠近(對于Sigmoid函數來說,意味著激活輸入值WU+B是大的負值或正值),所以這導致反向傳播時低層神經網絡的梯度消失,這是訓練深層神經網絡收斂越來越慢的本質原因,而BN就是通過一定的規范化手段,把每層神經網絡任意神經元這個輸入值的分布強行拉回到均值為0方差為1的標準正態分布,其實就是把越來越偏的分布強制拉回比較標準的分布,這樣使得激活輸入值落在非線性函數對輸入比較敏感的區域,這樣輸入的小變化就會導致損失函數較大的變化,意思是這樣讓梯度變大,避免梯度消失問題產生,而且梯度變大意味著學習收斂速度快,能大大加快訓練速度。BN具體操作流程如下圖所示:
二、nn.BatchNorm2d()方法詳解
清楚了BN的原理之后,便可以很快速的理解這個方法了。
- 方法
-
Parameters
num_features:圖像的通道數,也即(N, C, H, W)中的C的值
eps:增加至分母上的一個很小的數,為了防止/0情況的發生
momentum:用來計算平均值和方差的值,默認值為0.1
affine:一個布爾類型的值,當設置為True的時候,該模型對affine參數具有可學習的能力,默認為True
track_running_stats:一個布爾類型的值,用于記錄均值和方差,當設置為True的時候,模型會跟蹤均值和方差,反之,不會跟蹤均值和方差
-
Shape
Input: (N, C, H, W)
Output: (N, C, H, W)
三、案例分析
import torch.nn as nn import torch if __name__ == '__main__':bn = nn.BatchNorm2d(3)ip = torch.randn(2, 3, 2, 2)print(ip)output = bn(ip)print(output)- 運行結果
總結
以上是生活随笔為你收集整理的PyTorch基础(12)-- torch.nn.BatchNorm2d()方法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 拉面说:如何成为速食拉面独角兽
- 下一篇: PyTorch基础(13)-- torc