Pytorch教程之torch.mm、torch.bmm、torch.matmul、masked_fill
文章目錄
- 1、簡(jiǎn)介
- 2、torch.mm
- 3、torch.bmm
- 4、torch.matmul
- 5、masked_fill
1、簡(jiǎn)介
這幾天正在看NLP中的注意力機(jī)制,代碼中涉及到了一些關(guān)于張量矩陣乘法和填充一些代碼,這里積累一下。主要參考了pytorch2.0的官方文檔。
①torch.mm(input,mat2,*,out=None)
②torch.bmm(input,mat2,*,out=None)
③torch.matmul(input, other, *, out=None)
④Tensor.masked_fill
2、torch.mm
torch.mm語法為:
torch.mm(input, mat2, *, out=None) → Tensor就是矩陣的乘法。如果輸入input是(n,m),mat2是(m, p),則輸出為(n, p)。
示例:
3、torch.bmm
torch.bmm語法為:
torch.bmm(input, mat2, *, out=None) → Tensor- 功能:對(duì)存儲(chǔ)在input和mat2矩陣中的批數(shù)量的矩陣進(jìn)行乘積。
- 要求:input矩陣和mat2必須是三維的張量,且第一個(gè)維度即batch維度必須一樣。
- 舉例:如果input是一個(gè)(b, n , m)的張量,mat2是一個(gè)(b, m, p)張量,則輸出形狀為(b, n, p)
示例:
input = torch.randn(10, 3, 4) mat2 = torch.randn(10, 4, 5) res = torch.bmm(input, mat2) res.size() -->torch.Size([10, 3, 5])解讀:實(shí)際上刻畫的就是一組矩陣與另一組張量矩陣的乘積,至于一組有多少個(gè)矩陣,由input和mat2的第一個(gè)輸入維度決定,上述代碼第一個(gè)維度為10,就代表著10個(gè)形狀為(3, 4)的矩陣與10個(gè)形狀為(4, 5)的矩陣分別對(duì)應(yīng)相乘,得到10個(gè)形狀為(3, 5)的矩陣。
4、torch.matmul
torch.matmul語法為:
torch.matmul(input, other, *, out=None) → Tensor該函數(shù)刻畫的是兩個(gè)張量的乘積,且計(jì)算過程與張量的維度密切相關(guān)。
① 如果張量是一維的,輸出結(jié)果是點(diǎn)乘,是一個(gè)標(biāo)量。
a = torch.tensor([1,2,4]) b = torch.tensor([2,5,6]) print(torch.matmul(a, b)) print(a.shape) --> tensor(36) -->torch.Size([3])注意:張量a.shape顯示的是torch.Size([3]),只有一個(gè)維度,3是指這個(gè)維度中有3個(gè)數(shù)。
② 如果兩個(gè)張量都是二維的,執(zhí)行的是矩陣的乘法。
由上述示例可知,如果兩個(gè)張量均為2維,那么其運(yùn)算和torch.mm是一樣的。
③如果第一個(gè)參數(shù)input是1維的,第二個(gè)參數(shù)是二維的,那么在計(jì)算時(shí),在第一個(gè)參數(shù)前增加一個(gè)維度1,計(jì)算完畢之后再把這個(gè)維度去掉。
如上所示,a只有一個(gè)維度,在進(jìn)行計(jì)算時(shí),變成了(1, 3),則變成了(1, 3)乘以(3, 2),變成(1, 2),最后在去掉1這個(gè)維度。
④如果第一個(gè)參數(shù)是2維的,第二個(gè)參數(shù)是1維的,則返回矩陣-向量乘積。
矩陣乘以張量,就是矩陣中的每一行都與這個(gè)張量相乘,最終得到一個(gè)一維的,大小為3的結(jié)果。
⑤多個(gè)維度
- 如果兩個(gè)參數(shù)至少都是1維的,且有一個(gè)參數(shù)的維度N>2,則返回的是一個(gè)批矩陣的乘積(即把多出的那個(gè)維度看作batch即可,讓每個(gè)batch后的矩陣與后面的張量相乘即可)。
- 如果第一個(gè)參數(shù)是1維的,則在它的維度前加上1,以便批量矩陣相乘并在之后刪除。如果第二個(gè)參數(shù)是1維的,則將1追加到其維度,用于批處理矩陣倍數(shù),然后刪除。
- 舉例:如果input形狀是(j,1,n,n),other的張量形狀是(k,n,n),那么輸出張量的形狀將會(huì)是(j,k,n,n)。
- 如果input形狀是(j,1,n,m),other的張量形狀是(k,m,p),那么輸出張量的形狀將會(huì)是(j,k,n,p)。
仔細(xì)比較上述三個(gè)代碼塊,其最終的結(jié)果是一樣的。可以簡(jiǎn)單記為如果兩個(gè)維度不一致的話,多出的維度就看作是batch維,相當(dāng)于在低維度前面增加一個(gè)維度。
5、masked_fill
語法為:
Tensor.masked_fill_(mask, value)參數(shù):
- mask(BoolTensor):布爾掩碼
- value(float):用于填充的值。
mask是一個(gè)pytorch張量,元素是布爾值,value是要填充的值,填充規(guī)則是mask中取值為True的位置對(duì)應(yīng)與需要填充的張量中的位置用value填充。
a = torch.tensor([[0, 8],[ 6, 8],[ 7, 1] ])mask = torch.tensor([[ True, False],[False, False],[False, True] ]) b = a.masked_fill(mask, -1e9) print(b) -->tensor([[-1000000000, 8],[ 6, 8],[ 7, -1000000000]])總結(jié)
以上是生活随笔為你收集整理的Pytorch教程之torch.mm、torch.bmm、torch.matmul、masked_fill的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 国庆节堕落的日子
- 下一篇: 精品基于Uniapp+SSM实现的公园植