Batchsize不够大,如何发挥BN性能?探讨神经网络在小Batch下的训练方法
由于算力的限制,有時我們無法使用足夠大的batchsize,此時該如何使用BN呢?本文將介紹兩種在小batchsize也可以發揮BN性能的方法。
本文首發自極市平臺,作者 @皮特潘,轉載需獲授權。
前言
BN(Batch Normalization)幾乎是目前神經網絡的必選組件,但是使用BN有兩個前提要求:
不然的話,非但不能發揮BN的優勢,甚至會適得其反。但是由于算力的限制,有時我們無法使用足夠大的batchsize,此時該如何使用BN呢?本文介紹兩篇在小batchsize也可以發揮BN性能的方法。解決思路為:既然batchsize太小的情況下,無法保證當前minibatch收集到的數據和整體數據同分布。那么能否多收集幾個batch的數據進行統計呢?這兩篇工作分別分別是:
- BRN:Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
- CBN:Cross-Iteration Batch Normalization
另外,本文也會給出代碼解析,幫助大家理解。
batchsize過小的場景
通常情況下,大家對CNN任務的研究一般為公開的數據集指標負責。分類任務為ImageNet數據集負責,其尺度為224X224。檢測任務為coco數據集負責,其尺度為640X640左右。分割任務一般為coco或PASCAL VOC數據集負責,后者的尺度大概在500X500左右。再加上例如resize的前處理操作,真正送入網絡的圖片的分辨率都不算太大。一般性能的GPU也很容易實現大的batchsize(例如大于32)的支持。
但是實際的項目中,經常遇到需要處理的圖片尺度過大的場景,例如我們使用500w像素甚至2000w像素的工業相機進行數據采集,500w的相機采集的圖片尺度就是2500X2000左右。而對于微小的缺陷檢測、高精度的關鍵點檢測或小物體的目標檢測等任務,我們一般不太想粗暴降低輸入圖片的分辨率,這樣違背了我們使用高分辨率相機的初衷,也可能導致丟失有用特征。在算力有限的情況下,我們的batchsize就無法設置太大,甚至只能為1或2。小的batchsize會帶來很多訓練上的問題,其中BN問題就是最突出的。雖然大batchsize訓練是一個共識,但是現實中可能無法具有充足的資源,因此我們需要一些處理手段。
BN回顧
首先Batch Normalization 中的Normalization被稱為標準化,通過將數據進行平和縮放拉到一個特定的分布。BN就是在batch維度上進行數據的標準化。BN的引入是用來解決 internal covariate shift 問題,即訓練迭代中網絡激活的分布的變化對網絡訓練帶來的破壞。BN通過在每次訓練迭代的時候,利用minibatch計算出的當前batch的均值和方差,進行標準化來緩解這個問題。雖然How Does Batch Normalization Help Optimization 這篇文章探究了BN其實和Internal Covariate Shift (ICS)問題關系不大,本文不深入討論,這個會在以后的文章中細說。
一般來說,BN有兩個優點:
- 降低對初始化、學習率等超參的敏感程度,因為每層的輸入被BN拉成相對穩定的分布,也能加速收斂過程。
- 應對梯度飽和和梯度彌散,主要是對于使用sigmoid和tanh的激活函數的網絡。
當然,BN的使用也有兩個前提:
- minibatch和全部數據同分布。因為訓練過程每個minibatch從整體數據中均勻采樣,不同分布的話minibatch的均值和方差和訓練樣本整體的均值和方差是會存在較大差異的,在測試的時候會嚴重影響精度。
- batchsize不能太小,否則效果會較差,論文給的一般性下限是32。
再來回顧一下BN的具體做法:
- 訓練的時候:使用當前batch統計的均值和方差對數據進行標準化,同時優化優化gamma和beta兩個參數。另外利用指數滑動平均收集全局的均值和方差。
- 測試的時候:使用訓練時收集全局均值和方差以及優化好的gamma和beta進行推理。
可以看出,要想BN真正work,就要保證訓練時當前batch的均值和方差逼近全部數據的均值和方差。
BRN
論文題目:Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
論文地址: https://arxiv.org/pdf/1702.03275.pdf
代碼地址: https://github.com/ludvb/batchrenorm
核心解析:
本文的核心思想就是:訓練過程中,由于batchsize較小,當前minibatch統計到的均值和方差與全部數據有差異,那么就對當前的均值和方差進行修正。修正的方法主要是利用到通過滑動平均收集到的全局均值和標準差。看公式:
xi?μσ=xi?μBσB?r+d,where?r=σBσ,d=μB?μσ\frac{x_{i}-\mu}{\sigma}=\frac{x_{i}-\mu_{\mathcal{B}}}{\sigma_{\mathcal{B}}} \cdot r+d, \quad \text { where } r=\frac{\sigma_{\mathcal{B}}}{\sigma}, \quad d=\frac{\mu_{\mathcal{B}}-\mu}{\sigma} σxi??μ?=σB?xi??μB???r+d,?where?r=σσB??,d=σμB??μ?
上面公式中,i表示網絡的第i層。μ和σ表示網絡推理時的均值和標準差,也就是訓練過程通過滑動平均收集的到均值和方差。μB和σb表示當前訓練迭代過程中的實際統計到的均值和標準差。BN在小batch不work的根本原因就是這兩組參數存在較大的差異。通過r和d對訓練過程中數據進行線性變換,在該變化下,上公式左右兩端就嚴格相等了。其實標準的BN就是r=1,d=0的一種情況。對于某一個特定的minibatch,其中r和d可以看成是固定的,是直接計算出來的,不需要梯度優化的。
具體流程:
-
統計當前batch數據的均值和標注差,和標準BN做法一致。
-
根據當前batch的均值和標準差結合全局的均值和標準差利用上面的公式計算r和d;注意該運算是不參與梯度反向傳播的。另外,r和d需要增加一個限制,直接clip操作就好。
-
利用當前的均值和標準差對當前數據執行Normalization操作,利用上面計算得到的r和d對當前batch進行線性變換。
-
滑動平均收集全局均值和標注差。
測試過程和標準BN一樣。其實本質上,就是訓練的過程中使用全局的信息進行更新當前batch的數據。間接利用了全局的信息,而非當前這一個batch的信息。
實驗效果:
在較大的batchsize(32)的時候,與標準BN相比,不會丟失效果,訓練過程一如既往穩定高效。如下:
在小的batchsize(4)下, 本文做法依然接近batchsize為32的時候,可見在小batchsize下是work的。
代碼解析:
def forward(self, x):if x.dim() > 2:x = x.transpose(1, -1)if self.training: # 訓練過程dims = [i for i in range(x.dim() - 1)batch_mean = x.mean(dims) # 計算均值batch_std = x.std(dims, unbiased=False) + self.eps # 計算標準差# 按照公式計算r和dr = (batch_std.detach() / self.running_std.view_as(batch_std)).clamp_(1 / self.rmax, self.rmax)d = ((batch_mean.detach() - self.running_mean.view_as(batch_mean))/ self.running_std.view_as(batch_std)).clamp_(-self.dmax, self.dmax)# 對當前數據進行標準化和線性變換x = (x - batch_mean) / batch_std * r + d# 滑動平均收集全局均值和標注差self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean)self.running_std += self.momentum * (batch_std.detach() - self.running_std)self.num_batches_tracked += 1else: # 測試過程x = (x - self.running_mean) / self.running_stdreturn xCBN
論文題目:Cross-Iteration Batch Normalization
論文地址:https://arxiv.org/abs/2002.05712
代碼地址:https://github.com/Howal/Cross-iterationBatchNorm
本文認為BRN的問題在于它使用的全局均值和標準差不是當前網絡權重下獲取的,因此不是exactly正確的,所以batchsize再小一點,例如為1或2時就不太work了。本文使用泰勒多項式逼近原理來修正當前的均值和標準差,同樣也是間接利用了全局的均值和方差信息。簡述就是:當前batch的均值和方差來自之前的K次迭代均值和方差的平均,由于網絡權重一直在更新,所以不能直接粗暴求平均。本文而是利用泰勒公式估計前面的迭代在當前權重下的數值。
泰勒公式:
泰勒公式是一個用函數在某點的信息描述其附近取值的公式。如果函數滿足一定的條件,泰勒公式可以用函數在某一點的各階導數值做系數構建一個多項式來近似表達這個函數。教科書介紹如下:
核心解析:
本文做法,由于網絡一般使用SGD更新權重,因此網絡權重的變化是平滑的,所以適用泰勒公式。如下,t為訓練過程中當前迭代時刻,t-τ為t時刻向前τ時刻。θ為網絡權重,權重下標代表該權重的時刻。μ為當前minibatch均值,v為當強minibatch平方的均值,是為了計算標準差。因此直接套用泰勒公式得到:
μt?τ(θt)=μt?τ(θt?τ)+?μt?τ(θt?τ)?θt?τ(θt?θt?τ)+O(∥θt?θt?τ∥2)(5)\begin{aligned} \mu_{t-\tau}\left(\theta_{t}\right)=& \mu_{t-\tau}\left(\theta_{t-\tau}\right)+\frac{\partial \mu_{t-\tau}\left(\theta_{t-\tau}\right)}{\partial \theta_{t-\tau}}\left(\theta_{t}-\theta_{t-\tau}\right) \\ &+\mathbf{O}\left(\left\|\theta_{t}-\theta_{t-\tau}\right\|^{2}\right) \end{aligned}\tag{5} μt?τ?(θt?)=?μt?τ?(θt?τ?)+?θt?τ??μt?τ?(θt?τ?)?(θt??θt?τ?)+O(∥θt??θt?τ?∥2)?(5)
νt?τ(θt)=νt?τ(θt?τ)+?νt?τ(θt?τ)?θt?τ(θt?θt?τ)+O(∥θt?θt?τ∥2)(6)\begin{aligned} \nu_{t-\tau}\left(\theta_{t}\right)=& \nu_{t-\tau}\left(\theta_{t-\tau}\right)+\frac{\partial \nu_{t-\tau}\left(\theta_{t-\tau}\right)}{\partial \theta_{t-\tau}}\left(\theta_{t}-\theta_{t-\tau}\right) \\ &+\mathbf{O}\left(\left\|\theta_{t}-\theta_{t-\tau}\right\|^{2}\right) \end{aligned}\tag{6} νt?τ?(θt?)=?νt?τ?(θt?τ?)+?θt?τ??νt?τ?(θt?τ?)?(θt??θt?τ?)+O(∥θt??θt?τ?∥2)?(6)
上面這兩個公式就是為了估計在t-τ時刻,t時刻的權重下的均值和方差的參數估計。BRN可以看作沒有進行該方法估計,使用的依然是t-τ時刻權重的參數估計。其中O為高階項,因為該式主要由一階項控制,因此高階項目可以忽略。上面的公式還要進一步簡化,主要是偏導項的求法。假設當前層為l,實際上?μ/ ?θ 和 ?ν/?θ依賴與所有l層之前層的權重,求導計算量極大。不過經驗觀察到,l層之前層的偏數下降很快,因此可以忽略掉,僅僅計算當前層的權重偏導。
因此化簡為如下,可以看出,求偏導的部分,只考慮對當前層的偏導數,注意上標l表示網絡層的意思。至此,之前時刻在當前權重下的均值和方差已經估計出來了。
μt?τl(θt)≈μt?τl(θt?τ)+?μt?τl(θt?τ)?θt?τl(θtl?θt?τl)(7)\mu_{t-\tau}^{l}\left(\theta_{t}\right) \approx \mu_{t-\tau}^{l}\left(\theta_{t-\tau}\right)+\frac{\partial \mu_{t-\tau}^{l}\left(\theta_{t-\tau}\right)}{\partial \theta_{t-\tau}^{l}}\left(\theta_{t}^{l}-\theta_{t-\tau}^{l}\right)\tag{7} μt?τl?(θt?)≈μt?τl?(θt?τ?)+?θt?τl??μt?τl?(θt?τ?)?(θtl??θt?τl?)(7)
νt?τl(θt)≈νt?τl(θt?τ)+?νt?τl(θt?τ)?θt?τl(θtl?θt?τl)(8)\nu_{t-\tau}^{l}\left(\theta_{t}\right) \approx \nu_{t-\tau}^{l}\left(\theta_{t-\tau}\right)+\frac{\partial \nu_{t-\tau}^{l}\left(\theta_{t-\tau}\right)}{\partial \theta_{t-\tau}^{l}}\left(\theta_{t}^{l}-\theta_{t-\tau}^{l}\right)\tag{8} νt?τl?(θt?)≈νt?τl?(θt?τ?)+?θt?τl??νt?τl?(θt?τ?)?(θtl??θt?τl?)(8)
下面穿插代碼解析整個計算過程。
首先是統計計算當前batch的數據,和標準BN沒有差別。代碼為:
cur_mu = y.mean(dim=1) # 當前層的均值 cur_meanx2 = torch.pow(y, 2).mean(dim=1) # 當前值平方的均值,計算標準差使用 cur_sigma2 = y.var(dim=1) # 當前值的方差對當前網絡層求偏導,直接使用torch的內置函數。代碼:
# 注意 grad_outputs = self.ones : 不同值的梯度對結果影響程度不同,類似torch.sum()的作用。 dmudw = torch.autograd.grad(cur_mu, weight, self.ones, retain_graph=True)[0] dmeanx2dw = torch.autograd.grad(cur_meanx2, weight, self.ones, retain_graph=True)[0]使用公式(7)和(8)繼續下面的計算,也就是向前累計K次估計數值,更新到當前batch的均值和方差的計算上,這里引入了一個超參就是k的大小,它表示當前的迭代向后回溯到多長的步長的迭代。實驗探究k=8是一個比較折中的選擇。k=1的時候,RBN退化成了原始的BN:
μˉt,kl(θt)=1k∑τ=0k?1μt?τl(θt)(9)\bar{\mu}_{t, k}^{l}\left(\theta_{t}\right)=\frac{1}{k} \sum_{\tau=0}^{k-1} \mu_{t-\tau}^{l}\left(\theta_{t}\right)\tag{9} μˉ?t,kl?(θt?)=k1?τ=0∑k?1?μt?τl?(θt?)(9)
νˉt,kl(θt)=1k∑τ=0k?1max?[νt?τl(θt),μt?τl(θt)2](10)\bar{\nu}_{t, k}^{l}\left(\theta_{t}\right)=\frac{1}{k} \sum_{\tau=0}^{k-1} \max \left[\nu_{t-\tau}^{l}\left(\theta_{t}\right), \mu_{t-\tau}^{l}\left(\theta_{t}\right)^{2}\right]\tag{10} νˉt,kl?(θt?)=k1?τ=0∑k?1?max[νt?τl?(θt?),μt?τl?(θt?)2](10)
σˉt,kl(θt)=νˉt,kl(θt)?μˉt,kl(θt)2(11)\bar{\sigma}_{t, k}^{l}\left(\theta_{t}\right)=\sqrt{\bar{\nu}_{t, k}^{l}\left(\theta_{t}\right)-\bar{\mu}_{t, k}^{l}\left(\theta_{t}\right)^{2}}\tag{11} σˉt,kl?(θt?)=νˉt,kl?(θt?)?μˉ?t,kl?(θt?)2?(11)
代碼如下,其中這里的self.pre_mu, self.pre_dmudw, self.pre_weight是前面每次迭代收集到了窗口k大小的數值,分別代表均值、均值對權重的偏導、權重。self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight同理,是對應平方均值的。
# 利用泰勒公式估計 mu_all = torch.stack \([cur_mu, ] + [tmp_mu + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_mu, tmp_d, tmp_w in zip(self.pre_mu, self.pre_dmudw, self.pre_weight)])meanx2_all = torch.stack \([cur_meanx2, ] + [tmp_meanx2 + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_meanx2, tmp_d, tmp_w in zip(self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)])上面所說的變量收集迭代過程如下:
# 動態維護buffer_num長度的均值、均值平方、偏導、權重 self.pre_mu = [cur_mu.detach(), ] + self.pre_mu[:(self.buffer_num - 1)] self.pre_meanx2 = [cur_meanx2.detach(), ] + self.pre_meanx2[:(self.buffer_num - 1)] self.pre_dmudw = [dmudw.detach(), ] + self.pre_dmudw[:(self.buffer_num - 1)] self.pre_dmeanx2dw = [dmeanx2dw.detach(), ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)] tmp_weight = torch.zeros_like(weight.data) tmp_weight.copy_(weight.data) self.pre_weight = [tmp_weight.detach(), ] + self.pre_weight[:(self.buffer_num - 1)]計算獲取當前batch的均值和方差,取修正后的K次迭代數據的平均即可。
# 利用收集到的一定窗口長度的均值和平方均值,計算當前均值和方差 sigma2_all = meanx2_all - torch.pow(mu_all, 2) re_mu_all = mu_all.clone() re_meanx2_all = meanx2_all.clone() re_mu_all[sigma2_all < 0] = 0 re_meanx2_all[sigma2_all < 0] = 0 count = (sigma2_all >= 0).sum(dim=0).float() mu = re_mu_all.sum(dim=0) / count # 平均操作 sigma2 = re_meanx2_all.sum(dim=0) / count - torch.pow(mu, 2)均值和方差使用過程,和標準BN沒有區別。
# 標準化過程,和原始BN沒有區別 y = y - mu.view(-1, 1) if self.out_p: # 僅僅控制開平方的位置y = y / (sigma2.view(-1, 1) + self.eps) ** .5 else:y = y / (sigma2.view(-1, 1) ** .5 + self.eps)最后再理解一下:mu_0是當前batch統計獲取的均值,mu_1是上一batch統計獲取的均值。 當前batch計算BN的時候也想利用到mu_1,但是統計mu_1的時候利用到網絡的權重也是上一次的,直接使用肯定有問題,所以本文使用泰勒公式估計出mu_1在當前權重下應該是什么樣子。方差估計同理。
實驗效果:
這里的Naive CBN 是上一篇論文BRN的做法,可以認為是CBN不使用泰勒估計的一種特例。在batchsize下降的過程中,CBN指標依然堅挺,甚至超過了GN(不過也側面反應了GN確實厲害)。而原始BN和其改進版BRN在batchsize更小的時候都不太work了。
總結
以上是生活随笔為你收集整理的Batchsize不够大,如何发挥BN性能?探讨神经网络在小Batch下的训练方法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 通道注意力新突破!从频域角度出发,浙大提
- 下一篇: 开源项目|基于darknet实现量化感知