日韩av黄I国产麻豆传媒I国产91av视频在线观看I日韩一区二区三区在线看I美女国产在线I麻豆视频国产在线观看I成人黄色短片

歡迎訪問 生活随笔!

生活随笔

當前位置: 首頁 >

PyTorch基础(13)-- torch.nn.Unfold()方法

發布時間:2025/3/15 28 豆豆
生活随笔 收集整理的這篇文章主要介紹了 PyTorch基础(13)-- torch.nn.Unfold()方法 小編覺得挺不錯的,現在分享給大家,幫大家做個參考.

前言

最近在看新論文的過程中,發現新論文中的代碼非常簡潔,只用了unfold和fold方法便高效的將論文的思想表達出,因此學習記錄一下unfold和fold方法。

一、方法詳解

  • 方法
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
  • parameters
    • kernel_size (int or tuple) – 滑動窗口的size

    • stride (int or tuple, optional) – 空間維度上滑動的步長,默認步長為1

    • padding (int or tuple, optional) – implicit zero padding to be added on both sides of input. Default: 0

    • dilation (int or tuple, optional) – 空洞卷積的擴充率,默認為1

  • 釋義:提取滑動窗口滑過的所有值,例如下面的例子中,
[[ 0.4009, 0.6350, -0.5197, 0.8148, -0.7235], [-1.2102, 0.4621, -0.3421, -0.9261, -2.8376], [-1.5553, 0.1713, 0.6820, -2.0880, -0.0204], [ 1.1419, -0.4881, -0.9510, -0.0367, -0.8108], [ 0.1459, -0.4568, 1.0039, -1.2385, -1.4467]]

kernel size =3 的窗口滑過,會首先記錄

[[ 0.4009, 0.6350, -0.5197, -1.2102, 0.4621, -0.3421, -1.5553, 0.1713, 0.6820],[ 0.6350, -0.5197, 0.8148, 0.4621, -0.3421, -0.9261, 0.1713, 0.6820, -2.0880],[-0.5197, 0.8148, -0.7235, -0.3421, -0.9261, -2.8376, 0.6820, -2.0880, -0.0204],[-1.2102, 0.4621, -0.3421, -1.5553, 0.1713, 0.6820, 1.1419, -0.4881, -0.9510],[ 0.4621, -0.3421, -0.9261, 0.1713, 0.6820, -2.0880, -0.4881, -0.9510, -0.0367],[-0.3421, -0.9261, -2.8376, 0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108],[-1.5553, 0.1713, 0.6820, 1.1419, -0.4881, -0.9510, 0.1459, -0.4568, 1.0039],[ 0.1713, 0.6820, -2.0880, -0.4881, -0.9510, -0.0367, -0.4568, 1.0039, -1.2385],[ 0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108, 1.0039, -1.2385, -1.4467]]
  • Note:unfold方法的輸入只能是4維的,即(N,C,H,W)

二、如何計算輸出的size

  • 栗子
import torch import torch.nn as nn if __name__ == '__main__':x = torch.randn(2, 3, 5, 5)print(x)unfold = nn.Unfold(2)y = unfold(x)print(y.size())print(y)
  • 運行結果
torch.Size([2, 12, 16])

接下來,我們一步一步分析這個結果是怎么計算出來的!

首先,要知道的是,我們的輸入必須是4維的,即(B,C,H,W),其中,B表示Batch size;C代表通道數;H代表feature map的高;W表示feature map的寬。首先,我們假設經過Unfolder處理之后的size為(B,h,w)。然后我們需要計算h(即輸出的高),計算公式如下所示:

這里是引用舉個栗子:假設輸入通道數為3,kernel size為(2,2),圖片最常見的通道數為3(所以我們拿來舉例),經過Unfolder方法后,輸出的高變為322=12,即輸出的H為12。

計算完成之后,我們需要計算w,計算公式如下所示:

其中,d代表的是空間的所有維度數,例如空間維度為(H,W),則d=2。下面通過舉例,我們來計算輸出的w。

舉個栗子:如果輸入的H、W分別為5,kernel size為2,則輸出的w為

4*4=16,故最終的輸出size為[2,12,16]。

三、案例

  • 案例
import torch import torch.nn as nn if __name__ == '__main__':x = torch.randn(1, 3, 5, 5)print(x)unfold = nn.Unfold(kernel_size=3)output = unfold(x)print(output, output.size())
  • 運行結果
tensor([[[[ 0.4009, 0.6350, -0.5197, 0.8148, -0.7235],[-1.2102, 0.4621, -0.3421, -0.9261, -2.8376],[-1.5553, 0.1713, 0.6820, -2.0880, -0.0204],[ 1.1419, -0.4881, -0.9510, -0.0367, -0.8108],[ 0.1459, -0.4568, 1.0039, -1.2385, -1.4467]],[[-0.9973, -0.7601, -0.2161, 1.2120, -0.3036],[-0.7279, 0.0833, -0.8886, -0.9168, 0.7503],[-0.6748, 0.7064, 0.6903, -1.0447, 0.8688],[-0.5230, -1.2308, -0.3932, 1.2521, -0.2523],[-0.3930, 0.6452, 0.1690, 0.3744, 0.2015]],[[ 0.6403, 1.3915, -1.9529, 0.2899, -0.8897],[-0.1720, 1.0843, -1.0177, -1.7480, -0.5217],[-0.9648, -0.0867, -0.2926, 0.3010, 0.3192],[ 0.1181, -0.2218, 0.0766, 0.5914, -0.8932],[-0.4508, -0.3964, 1.1163, 0.6776, -0.8948]]]]) tensor([[[ 0.4009, 0.6350, -0.5197, -1.2102, 0.4621, -0.3421, -1.5553,0.1713, 0.6820],[ 0.6350, -0.5197, 0.8148, 0.4621, -0.3421, -0.9261, 0.1713,0.6820, -2.0880],[-0.5197, 0.8148, -0.7235, -0.3421, -0.9261, -2.8376, 0.6820,-2.0880, -0.0204],[-1.2102, 0.4621, -0.3421, -1.5553, 0.1713, 0.6820, 1.1419,-0.4881, -0.9510],[ 0.4621, -0.3421, -0.9261, 0.1713, 0.6820, -2.0880, -0.4881,-0.9510, -0.0367],[-0.3421, -0.9261, -2.8376, 0.6820, -2.0880, -0.0204, -0.9510,-0.0367, -0.8108],[-1.5553, 0.1713, 0.6820, 1.1419, -0.4881, -0.9510, 0.1459,-0.4568, 1.0039],[ 0.1713, 0.6820, -2.0880, -0.4881, -0.9510, -0.0367, -0.4568,1.0039, -1.2385],[ 0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108, 1.0039,-1.2385, -1.4467],[-0.9973, -0.7601, -0.2161, -0.7279, 0.0833, -0.8886, -0.6748,0.7064, 0.6903],[-0.7601, -0.2161, 1.2120, 0.0833, -0.8886, -0.9168, 0.7064,0.6903, -1.0447],[-0.2161, 1.2120, -0.3036, -0.8886, -0.9168, 0.7503, 0.6903,-1.0447, 0.8688],[-0.7279, 0.0833, -0.8886, -0.6748, 0.7064, 0.6903, -0.5230,-1.2308, -0.3932],[ 0.0833, -0.8886, -0.9168, 0.7064, 0.6903, -1.0447, -1.2308,-0.3932, 1.2521],[-0.8886, -0.9168, 0.7503, 0.6903, -1.0447, 0.8688, -0.3932,1.2521, -0.2523],[-0.6748, 0.7064, 0.6903, -0.5230, -1.2308, -0.3932, -0.3930,0.6452, 0.1690],[ 0.7064, 0.6903, -1.0447, -1.2308, -0.3932, 1.2521, 0.6452,0.1690, 0.3744],[ 0.6903, -1.0447, 0.8688, -0.3932, 1.2521, -0.2523, 0.1690,0.3744, 0.2015],[ 0.6403, 1.3915, -1.9529, -0.1720, 1.0843, -1.0177, -0.9648,-0.0867, -0.2926],[ 1.3915, -1.9529, 0.2899, 1.0843, -1.0177, -1.7480, -0.0867,-0.2926, 0.3010],[-1.9529, 0.2899, -0.8897, -1.0177, -1.7480, -0.5217, -0.2926,0.3010, 0.3192],[-0.1720, 1.0843, -1.0177, -0.9648, -0.0867, -0.2926, 0.1181,-0.2218, 0.0766],[ 1.0843, -1.0177, -1.7480, -0.0867, -0.2926, 0.3010, -0.2218,0.0766, 0.5914],[-1.0177, -1.7480, -0.5217, -0.2926, 0.3010, 0.3192, 0.0766,0.5914, -0.8932],[-0.9648, -0.0867, -0.2926, 0.1181, -0.2218, 0.0766, -0.4508,-0.3964, 1.1163],[-0.0867, -0.2926, 0.3010, -0.2218, 0.0766, 0.5914, -0.3964,1.1163, 0.6776],[-0.2926, 0.3010, 0.3192, 0.0766, 0.5914, -0.8932, 1.1163,0.6776, -0.8948]]]) torch.Size([1, 27, 9])

覺得寫的不錯的話,歡迎點贊+評論+收藏,這對我幫助很大!

總結

以上是生活随笔為你收集整理的PyTorch基础(13)-- torch.nn.Unfold()方法的全部內容,希望文章能夠幫你解決所遇到的問題。

如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。