TensorFlow的batch_normalization
批量標準化(batch normalization簡稱BN)主要是為了克服當神經網絡層數加深而導致難以訓練而誕生的。當深度神經網絡隨著網絡深度加深,訓練起來會越來越困難,收斂速度會很慢,還會產生梯度消失問題(vanishing gradient problem)。
在統計機器學習領域中有一個ICS(Internal Covariate Shift)理論:源域(source domain)和目標域(target domain)的數據分布是一致的。也就是訓練數據和測試數據滿足相同的分布,這是通過訓練數據獲得的模型在測試數據上有一個好的效果的保證。
Covariate Shift是指訓練數據的樣本和測試數據的樣本分布不一致時,訓練獲取的模型無法很好的泛化。它是分布不一致假設之下的一個分支問題,也就是指源域和目標域的條件概率是一致的,但是其邊緣概率不同。對于神經網絡而言,神經網絡的各層輸出,在經過了層內操作后,各層輸出分布會隨著輸入分布的變化而變化,而且差異會隨著網絡的深度增加而加大,但是每一層隨指向的樣本標記是不會改變的。
解決Covariate Shift問題可以通過對訓練樣本和測試樣本的比例對訓練樣本做一個矯正,通過批量標準化來標準化某些層或所有層的輸入,從而固定每層輸入信號的均值與方差。
一、批量標準化的實現
批量標準化是在激活函數之前,對z=wx+b做標準化,使得輸出結果滿足標準的正態分布,即均值為0,方差為1。讓每一層的輸入有一個穩定的分布便于網絡的訓練。
二、批量標準化的優點
1、加大探索的步長,加快模型收斂的速度
2、更容易跳出局部最小值
3、破壞原來的數據分布,在一定程度上可以緩解過擬合。
當遇到神經網絡收斂速度很慢或梯度爆炸等無法訓練的情況時,可以嘗試使用批量標準化來解決問題。
三、TensorFlow的批量標準化實例
1、tf.nn.moments(x,axes,shift=None,name=None,keep_dims=False)
函數介紹:計算x的均值和方差
參數介紹:
- x:需要計算均值和方差的tensor
- axes:指定求解x某個維度上的均值和方差,如果x是一維tensor,則axes=[0]
- name:用于計算均值和方差操作的名稱
- keep_dims:是否產生與輸入相同相同維度的結果
2、tf.nn.batch_normalization(x,mean,variance,offset,scale,variance_epsilon,name=None)
函數介紹:計算batch normalization
參數介紹:
- x:輸入的tensor,具有任意的維度
- mean:輸入tensor的均值
- variance:輸入tensor的方差
- offset:偏置tensor,初始化為1
- scale:比例tensor,初始化為0
- variance_epsilon:一個接近于0的數,避免除以0
總結
以上是生活随笔為你收集整理的TensorFlow的batch_normalization的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: sorce insight 4.0 编辑
- 下一篇: 域内双向NAT技术