PyTorch的torch.cat
-
字面理解:torch.cat是將兩個張量(tensor)拼接在一起,cat是concatnate的意思,即拼接,聯系在一起。
-
例子理解
import torch
A=torch.ones(2,3) #2x3的張量(矩陣)
A
tensor([[ 1., 1., 1.],
[ 1., 1., 1.]])
B=2torch.ones(4,3)#4x3的張量(矩陣)
B
tensor([[ 2., 2., 2.],
[ 2., 2., 2.],
[ 2., 2., 2.],
[ 2., 2., 2.]])
C=torch.cat((A,B),0)#按維數0(行)拼接
C
tensor([[ 1., 1., 1.],
[ 1., 1., 1.],
[ 2., 2., 2.],
[ 2., 2., 2.],
[ 2., 2., 2.],
[ 2., 2., 2.]])
C.size()
torch.Size([6, 3])
D=2torch.ones(2,4) #2x4的張量(矩陣)
C=torch.cat((A,D),1)#按維數1(列)拼接
C
tensor([[ 1., 1., 1., 2., 2., 2., 2.],
[ 1., 1., 1., 2., 2., 2., 2.]])
C.size()
torch.Size([2, 7])
上面給出了兩個張量A和B,分別是2行3列,4行3列。即他們都是2維張量。因為只有兩維,這樣在用torch.cat拼接的時候就有兩種拼接方式:按行拼接和按列拼接。即所謂的維數0和維數1.
C=torch.cat((A,B),0)就表示按維數0(行)拼接A和B,也就是豎著拼接,A上B下。此時需要注意:列數必須一致,即維數1數值要相同,這里都是3列,方能列對齊。拼接后的C的第0維是兩個維數0數值和,即2+4=6.
C=torch.cat((A,B),1)就表示按維數1(列)拼接A和B,也就是橫著拼接,A左B右。此時需要注意:行數必須一致,即維數0數值要相同,這里都是2行,方能行對齊。拼接后的C的第1維是兩個維數1數值和,即3+4=7.
從2維例子可以看出,使用torch.cat((A,B),dim)時,除拼接維數dim數值可不同外其余維數數值需相同,方能對齊。
3.實例
在深度學習處理圖像時,常用的有3通道的RGB彩色圖像及單通道的灰度圖。張量size為cxhxw,即通道數x圖像高度x圖像寬度。在用torch.cat拼接兩張圖像時一般要求圖像大小一致而通道數可不一致,即h和w同,c可不同。當然實際有3種拼接方式,另兩種好像不常見。比如經典網絡結構:U-Net
里面用到4次torch.cat,其中copy and crop操作就是通過torch.cat來實現的。可以看到通過上采樣(up-conv 2x2)將原始圖像h和w變為原來2倍,再和左邊直接copy過來的同樣h,w的圖像拼接。這樣做,可以有效利用原始結構信息。
4.總結
使用torch.cat((A,B),dim)時,除拼接維數dim數值可不同外其余維數數值需相同,方能對齊。
總結
以上是生活随笔為你收集整理的PyTorch的torch.cat的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 查看分析网络层次
- 下一篇: pycharm连接远程服务器并进行代码上