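Preface
While reading a recent paper, I noticed that its code was very concise: it expressed the paper's idea efficiently using nothing more than the unfold and fold methods. These are my study notes on unfold and fold.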
1. Method details
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
- Parameters
- kernel_size (int or tuple) – size of the sliding window
- stride (int or tuple, optional) – stride of the sliding window in the spatial dimensions. Default: 1
- padding (int or tuple, optional) – implicit zero padding to be added on both sides of input. Default: 0
- dilation (int or tuple, optional) – dilation rate, as in dilated (atrous) convolution. Default: 1
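- Explanation: Unfold extracts every value covered by the sliding window at each position it visits. For example, take the following 5×5 input: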
[[ 0.4009,  0.6350, -0.5197,  0.8148, -0.7235],
 [-1.2102,  0.4621, -0.3421, -0.9261, -2.8376],
 [-1.5553,  0.1713,  0.6820, -2.0880, -0.0204],
 [ 1.1419, -0.4881, -0.9510, -0.0367, -0.8108],
 [ 0.1459, -0.4568,  1.0039, -1.2385, -1.4467]]
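When a window with kernel size = 3 slides over it, Unfold records the matrix below. Each column is one flattened 3×3 window (there are 9 window positions), and each row holds the value at one fixed kernel offset across all windows: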
[[ 0.4009,  0.6350, -0.5197, -1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820],
 [ 0.6350, -0.5197,  0.8148,  0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880],
 [-0.5197,  0.8148, -0.7235, -0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204],
 [-1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510],
 [ 0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367],
 [-0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108],
 [-1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510,  0.1459, -0.4568,  1.0039],
 [ 0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367, -0.4568,  1.0039, -1.2385],
 [ 0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108,  1.0039, -1.2385, -1.4467]]
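- Note: Unfold expects a 4-D input, i.e. (N, C, H, W). The PyTorch documentation states that only 4-D (batched image-like) input tensors are currently supported.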
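To make the "one column per window" layout easier to see, here is a minimal sketch that swaps the random input for a 5×5 tensor holding 0..24, so every value encodes its own position. The tensor contents and variable names are my own choices for illustration, not part of the original example.

```python
import torch
import torch.nn as nn

# Single-channel 5x5 input filled with 0..24 (illustrative values, not from the article).
x = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)

unfold = nn.Unfold(kernel_size=3)
y = unfold(x)  # 9 kernel offsets (rows) x 9 window positions (columns)
print(y.shape)  # torch.Size([1, 9, 9])

# Column j is the j-th 3x3 window, flattened row by row.
# Windows are visited left to right, then top to bottom.
for j in range(9):
    r, c = divmod(j, 3)
    window = x[0, 0, r:r + 3, c:c + 3].reshape(-1)
    assert torch.equal(y[0, :, j], window)

first_window = y[0, :, 0]
print(first_window.tolist())  # [0.0, 1.0, 2.0, 5.0, 6.0, 7.0, 10.0, 11.0, 12.0]
```

With a 5×5 input and a 3×3 kernel there happen to be 3 window positions per row, the same as the kernel width, which makes the 9×9 result symmetric. That is a coincidence of these sizes, not a general property of Unfold.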
2. How to compute the output size
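import torch
import torch.nn as nn

if __name__ == '__main__':
    # Random input: batch of 2, 3 channels, 5x5 feature map
    x = torch.randn(2, 3, 5, 5)
    print(x)
    # Extract every 2x2 sliding window (stride 1, no padding)
    unfold = nn.Unfold(2)
    y = unfold(x)
    print(y.size())
    print(y)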
torch.Size([2, 12, 16])
Next, let's work through, step by step, how this result is computed.
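First, recall that the input must be 4-D, i.e. (B, C, H, W), where B is the batch size, C is the number of channels, H is the height of the feature map, and W is its width. Suppose the size after Unfold is (B, h, w). We start with h (the "height" of the output), which is computed as:

$$h = C \times \prod_{d} \text{kernel\_size}[d]$$

For example: take 3 input channels (the most common channel count for images, which is why we use it here) and a kernel size of (2, 2). After Unfold, h = 3 × 2 × 2 = 12.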
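The 12 rows are not in arbitrary order: they are grouped channel by channel, with kernel_size × kernel_size rows per channel. The sketch below checks this by comparing slices of the full output against Unfold applied to each channel on its own. The seed and variable names are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
x = torch.randn(2, 3, 5, 5)

unfold = nn.Unfold(kernel_size=2)
y = unfold(x)  # (2, 12, 16), where 12 = 3 channels * 2 * 2 kernel elements

# Split the 12 rows into (channels, kernel elements per channel).
k_elems = 2 * 2
per_channel = y.reshape(2, 3, k_elems, 16)

for c in range(3):
    # Unfold a single channel; slicing with c:c + 1 keeps the input 4-D.
    only_c = unfold(x[:, c:c + 1])  # (2, 4, 16)
    assert torch.equal(per_channel[:, c], only_c)

print(y.shape)  # torch.Size([2, 12, 16])
```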
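With h done, we compute w using:

$$w = \prod_{d} \left\lfloor \frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1 \right\rfloor$$

Here d ranges over all spatial dimensions; for spatial dimensions (H, W) there are 2 of them. Let's compute w with an example.
For example: if the input H and W are both 5 and the kernel size is 2, then with the defaults (padding 0, dilation 1, stride 1) each spatial dimension gives (5 + 0 − 1 − 1) / 1 + 1 = 4 window positions, so w = 4 × 4 = 16, and the final output size is [2, 12, 16].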
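The two formulas can be sketched as a small helper and checked against what nn.Unfold actually returns. unfold_output_shape is a hypothetical function written for this note, not part of PyTorch, and it assumes every argument is a plain int applied to both spatial dimensions.

```python
import torch
import torch.nn as nn


def unfold_output_shape(input_shape, kernel_size, dilation=1, padding=0, stride=1):
    """Hypothetical helper: predict the (B, h, w) shape produced by nn.Unfold.

    Assumes a (B, C, H, W) input and int arguments shared by both spatial dims.
    """
    b, c, *spatial = input_shape
    # h = C * product of the kernel size over the spatial dimensions
    h = c * kernel_size ** len(spatial)
    # w = product over spatial dims of the number of window positions
    w = 1
    for size in spatial:
        w *= (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return (b, h, w)


x = torch.randn(2, 3, 5, 5)

# The example from the text: kernel size 2 with default settings.
print(unfold_output_shape(x.shape, kernel_size=2))  # (2, 12, 16)
assert unfold_output_shape(x.shape, 2) == tuple(nn.Unfold(2)(x).shape)

# A non-default setting: padding 1 and stride 2 give 3 positions per dimension.
predicted = unfold_output_shape(x.shape, 2, padding=1, stride=2)
actual = tuple(nn.Unfold(2, padding=1, stride=2)(x).shape)
print(predicted, actual)  # (2, 12, 9) (2, 12, 9)
assert predicted == actual
```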
3. Example
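import torch
import torch.nn as nn

if __name__ == '__main__':
    # Random input: batch of 1, 3 channels, 5x5 feature map
    x = torch.randn(1, 3, 5, 5)
    print(x)
    # Extract every 3x3 sliding window (stride 1, no padding)
    unfold = nn.Unfold(kernel_size=3)
    output = unfold(x)
    print(output, output.size())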
tensor([[[[ 0.4009,  0.6350, -0.5197,  0.8148, -0.7235],
          [-1.2102,  0.4621, -0.3421, -0.9261, -2.8376],
          [-1.5553,  0.1713,  0.6820, -2.0880, -0.0204],
          [ 1.1419, -0.4881, -0.9510, -0.0367, -0.8108],
          [ 0.1459, -0.4568,  1.0039, -1.2385, -1.4467]],

         [[-0.9973, -0.7601, -0.2161,  1.2120, -0.3036],
          [-0.7279,  0.0833, -0.8886, -0.9168,  0.7503],
          [-0.6748,  0.7064,  0.6903, -1.0447,  0.8688],
          [-0.5230, -1.2308, -0.3932,  1.2521, -0.2523],
          [-0.3930,  0.6452,  0.1690,  0.3744,  0.2015]],

         [[ 0.6403,  1.3915, -1.9529,  0.2899, -0.8897],
          [-0.1720,  1.0843, -1.0177, -1.7480, -0.5217],
          [-0.9648, -0.0867, -0.2926,  0.3010,  0.3192],
          [ 0.1181, -0.2218,  0.0766,  0.5914, -0.8932],
          [-0.4508, -0.3964,  1.1163,  0.6776, -0.8948]]]])
tensor([[[ 0.4009,  0.6350, -0.5197, -1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820],
         [ 0.6350, -0.5197,  0.8148,  0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880],
         [-0.5197,  0.8148, -0.7235, -0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204],
         [-1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510],
         [ 0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367],
         [-0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108],
         [-1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510,  0.1459, -0.4568,  1.0039],
         [ 0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367, -0.4568,  1.0039, -1.2385],
         [ 0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108,  1.0039, -1.2385, -1.4467],
         [-0.9973, -0.7601, -0.2161, -0.7279,  0.0833, -0.8886, -0.6748,  0.7064,  0.6903],
         [-0.7601, -0.2161,  1.2120,  0.0833, -0.8886, -0.9168,  0.7064,  0.6903, -1.0447],
         [-0.2161,  1.2120, -0.3036, -0.8886, -0.9168,  0.7503,  0.6903, -1.0447,  0.8688],
         [-0.7279,  0.0833, -0.8886, -0.6748,  0.7064,  0.6903, -0.5230, -1.2308, -0.3932],
         [ 0.0833, -0.8886, -0.9168,  0.7064,  0.6903, -1.0447, -1.2308, -0.3932,  1.2521],
         [-0.8886, -0.9168,  0.7503,  0.6903, -1.0447,  0.8688, -0.3932,  1.2521, -0.2523],
         [-0.6748,  0.7064,  0.6903, -0.5230, -1.2308, -0.3932, -0.3930,  0.6452,  0.1690],
         [ 0.7064,  0.6903, -1.0447, -1.2308, -0.3932,  1.2521,  0.6452,  0.1690,  0.3744],
         [ 0.6903, -1.0447,  0.8688, -0.3932,  1.2521, -0.2523,  0.1690,  0.3744,  0.2015],
         [ 0.6403,  1.3915, -1.9529, -0.1720,  1.0843, -1.0177, -0.9648, -0.0867, -0.2926],
         [ 1.3915, -1.9529,  0.2899,  1.0843, -1.0177, -1.7480, -0.0867, -0.2926,  0.3010],
         [-1.9529,  0.2899, -0.8897, -1.0177, -1.7480, -0.5217, -0.2926,  0.3010,  0.3192],
         [-0.1720,  1.0843, -1.0177, -0.9648, -0.0867, -0.2926,  0.1181, -0.2218,  0.0766],
         [ 1.0843, -1.0177, -1.7480, -0.0867, -0.2926,  0.3010, -0.2218,  0.0766,  0.5914],
         [-1.0177, -1.7480, -0.5217, -0.2926,  0.3010,  0.3192,  0.0766,  0.5914, -0.8932],
         [-0.9648, -0.0867, -0.2926,  0.1181, -0.2218,  0.0766, -0.4508, -0.3964,  1.1163],
         [-0.0867, -0.2926,  0.3010, -0.2218,  0.0766,  0.5914, -0.3964,  1.1163,  0.6776],
         [-0.2926,  0.3010,  0.3192,  0.0766,  0.5914, -0.8932,  1.1163,  0.6776, -0.8948]]]) torch.Size([1, 27, 9])
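One reason this (1, 27, 9) layout is convenient: each column is a flattened 3×3×3 patch, so a convolution over the same input reduces to a single matrix multiplication. The sketch below checks that against F.conv2d. The weight tensor, its 4 output channels, and the seed are assumptions chosen only for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
x = torch.randn(1, 3, 5, 5)
weight = torch.randn(4, 3, 3, 3)  # hypothetical conv weight: 4 output channels

# Unfold into columns: one flattened 3x3x3 patch per window position.
cols = F.unfold(x, kernel_size=3)  # (1, 27, 9)

# Flatten each filter to 27 values and multiply: (4, 27) @ (1, 27, 9) -> (1, 4, 9).
out = weight.view(4, -1) @ cols

# The 9 window positions form a 3x3 output map.
out = out.view(1, 4, 3, 3)

# Compare with the built-in convolution (no bias, stride 1, no padding).
ref = F.conv2d(x, weight)
print(torch.allclose(out, ref, atol=1e-4))  # True
```

This works because the 27 rows are ordered channel by channel and then by kernel position, which is the same order that weight.view(4, -1) produces when it flattens each filter.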