Pytorch学习- 小型知识点汇总 unsqueeze()/squeeze() 和 .max() 等等
1. unsqueeze(input, dim, out=None)函數(shù) - 升維作用
參考鏈接
在指定的地方上增加一個維度
0(-2) [行擴展]: 表示在張量最外層增加一個中括號變成第一維
1(-1) [列擴展]:表示
2. squeeze(input,dim,out=None) 降維函數(shù)
將輸入張量形狀中的1 去除并返回。 如果輸入是形如(A×1×B×1×C×1×D),那么輸出形狀就為: (A×B×C×D)
小例子
如果是一個列表的tensor例如x變量想要轉(zhuǎn)換成相同維度的tensor可以采用如下方式:
1)循環(huán)遍歷列表中每個張量s,先使用unsqueeze(0)將每個張量s升維。
形狀由torch.Size([3])變?yōu)閠orch.Size([1, 3])
【變化前:tensor([0, 1, 2]) 變化后:tensor([[0, 1, 2]])】
2)同時使用torch.cat()將其拼接起來 dim=0 表示橫向拼接,否則豎向拼接
dim = 0 結(jié)果:
dim = 1 結(jié)果:
tensor([[0, 1, 2, 1, 0, 2, 1, 2, 0, 2, 1, 0]]) >>> import torch >>> x = [torch.tensor([0,1,2]),torch.tensor([1,0,2]),torch.tensor([1,2,0]),torch.tensor([2,1,0]),] >>> x [tensor([0, 1, 2]), tensor([1, 0, 2]), tensor([1, 2, 0]), tensor([2, 1, 0])] >>> x = torch.cat([s.unsqueeze(0) for s in l],0) >>> x tensor([[0, 1, 2],[1, 0, 2],[1, 2, 0],[2, 1, 0]])3. max()的用法
更加詳細參見我的另一篇文章:Pytorch學習-torch.max()和min()深度解析
non_final_next_states.max(1)[1].detach()
# 行維度 .max(1)[0] 返回values的最大值列表 .max(1)[1]返回最大值index列表
# 列維度 .max(0)[0] 返回values的最大值列表 .max(0)[1]返回最大值index列表
4. detach() 和detach_()
參考鏈接
torch.detach() - 返回一個新的沒有梯度的tensor [生成一個新的tensor]
返回一個新的tensor,從當前計算圖中分離下來的,但是仍指向原變量的存放位置,不同之處只是requires_grad為false,得到的這個tensor永遠不需要計算其梯度,不具有g(shù)rad。
即使之后重新將它的requires_grad置為true,它也不會具有梯度grad
torch.detach_() - 直接修改該tensor[對其本身的更改],將其設置為無自動計算梯度的張量
將一個tensor從創(chuàng)建它的圖中分離,并把它設置成葉子tensor
5. torch.Tensor和torch.tensor的區(qū)別
參考
在Pytorch中,Tensor和tensor都用于生成新的張量。
torch.Tensor() 生成單精度浮點型張量
- torch.Tensor()是Python類,更明確的說,是默認張量類型torch.FloatTensor()的別名,torch.Tensor([1,2]) 會調(diào)用Tensor類的構(gòu)造函數(shù)__init__,生成單精度浮點類型的張量。
torch.tensor() 根據(jù)原始data生成對應類型的張量
torch.tensor()僅僅是Python的函數(shù),函數(shù)原型是:
torch.tensor(data, dtype=None, device=None, requires_grad=False)其中data可以是:list, tuple, array, scalar等類型。
torch.tensor()可以從data中的數(shù)據(jù)部分做拷貝(而不是直接引用),根據(jù)原始數(shù)據(jù)類型生成相應的torch.LongTensor,torch.FloatTensor,torch.DoubleTensor。
5.torch.cat() 的用法
參考鏈接
總結(jié)
以上是生活随笔為你收集整理的Pytorch学习- 小型知识点汇总 unsqueeze()/squeeze() 和 .max() 等等的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Pytorch学习-Task1
- 下一篇: Pytorch学习 - Task6 Py