Batch Normalization原理及pytorch的nn.BatchNorm2d函数
下面通過(guò)舉個(gè)例子來(lái)說(shuō)明Batch Normalization的原理,我們假設(shè)在網(wǎng)絡(luò)中間經(jīng)過(guò)某些卷積操作之后的輸出的feature map的尺寸為4×3×2×2,4為batch的大小,3為channel的數(shù)目,2×2為feature map的長(zhǎng)寬
整個(gè)BN層的運(yùn)算過(guò)程如下圖:
上圖中,batch size一共是4, 對(duì)于每一個(gè)batch的feature map的size是3×2×2?
對(duì)于所有batch中的同一個(gè)channel的元素進(jìn)行求均值與方差,比如上圖,對(duì)于所有的batch,都拿出來(lái)最后一個(gè)channel,一共有4×2×2=16個(gè)元素
然后求區(qū)這16個(gè)元素的均值與方差。求取完了均值與方差之后,對(duì)于這16個(gè)元素中的每個(gè)元素進(jìn)行減去求取得到的均值,并除以方差,然后乘以gamma加上beta,公式如下:
?因?yàn)榍笕〉木蹬c方差是對(duì)于所有batch中的同一個(gè)channel進(jìn)行求取,batch normalization中的batch體現(xiàn)在這個(gè)地方
在pytorch求取batch normalization的函數(shù)是nn.BatchNorm2d(),其傳入?yún)?shù)是channels數(shù),例如上面的例子中,
nn.BatchNorm2d(3)?
總結(jié)
以上是生活随笔為你收集整理的Batch Normalization原理及pytorch的nn.BatchNorm2d函数的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: Python中的字典dict
- 下一篇: OpenCV学习笔记(十七):图像修补: