PyTorch 笔记(10)— Tensor 与 NumPy 相互转换、两种共享内存以及两者的广播法则
生活随笔
收集整理的這篇文章主要介紹了
PyTorch 笔记(10)— Tensor 与 NumPy 相互转换、两种共享内存以及两者的广播法则
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
Tensor 與 NumPy 有很高的相似性,彼此之間的互操作也非常簡單有效,需要注意的是 Tensor 與 NumPy 共享內存,由于 NumPy 歷史悠久,所以遇到 Tensor 不支持的操作時,可以先轉換成 NumPy ,處理后再轉換成 Tensor,轉換開銷很小。
1. Tensor 轉化為 NumPy
In [1]: import torch as tIn [2]: a = t.ones(5)In [3]: a
Out[3]: tensor([1., 1., 1., 1., 1.])In [4]: b = a.numpy()In [5]: b
Out[5]: array([1., 1., 1., 1., 1.], dtype=float32)
2. NumPy 轉化為 Tensor
In [6]: import numpy as npIn [7]: a = np.ones(5)In [8]: a
Out[8]: array([1., 1., 1., 1., 1.])In [9]: b =t.from_numpy(a)In [10]: b
Out[10]: tensor([1., 1., 1., 1., 1.], dtype=torch.float64)
值得注意的是,Torch 中的 Tensor 和 NumPy 中的 Array 共享內存位置,一個改變,另一個也同樣改變。注意使用的是 b.add_() 。
In [10]: b
Out[10]: tensor([1., 1., 1., 1., 1.], dtype=torch.float64)In [11]: b.add_(1)
Out[11]: tensor([2., 2., 2., 2., 2.], dtype=torch.float64)In [12]: a
Out[12]: array([2., 2., 2., 2., 2.])In [13]: b
Out[13]: tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
下面看使用 b.add() 發現 b 并沒有改變。
In [13]: b
Out[13]: tensor([2., 2., 2., 2., 2.], dtype=torch.float64)In [14]: b.add(2)
Out[14]: tensor([4., 4., 4., 4., 4.], dtype=torch.float64)In [15]: a
Out[15]: array([2., 2., 2., 2., 2.])In [16]: b
Out[16]: tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
b.add_() 和 b.add() 的區別:
任何操作符都固定地在前面加上 _ 來表示替換。例如:y.copy_(x) ,y.t_(),都將改變 y。
3. PyTorch 廣播法則
當輸入數組的某個維度的長度為 1 時,計算時沿此維度復制擴充成一樣的形狀。
可以通過以下兩個函數組合手動實現廣播法則:
unsqueeze或者view: 為數據的某一維的形狀補 1,實現法則 1expand或者expand_as,重復數組,實現法則 3;該操作不會復制數組,所以不會占用額外空間
注意: repeat 實現和 expand 相類似的功能,但是 repeat 會把形同的數據復制多份,因此會占用額外的空間。
3.1 自動廣播法則
In [17]: a = t.ones(3,2)In [18]: a
Out[18]:
tensor([[1., 1.],[1., 1.],[1., 1.]])In [19]: b = t.zeros(2,3,1)In [20]: b
Out[20]:
tensor([[[0.],[0.],[0.]],[[0.],[0.],[0.]]])
可以看到 a 是二維的,而 b 是三維的,但是可以通過廣播法則直接進行相加計算。
In [23]: a + b
Out[23]:
tensor([[[1., 1.],[1., 1.],[1., 1.]],[[1., 1.],[1., 1.],[1., 1.]]])In [24]:
3.2 手動廣播法則
In [24]: a.unsqueeze(0).expand(2,3,2) + b.expand(2,3,2)
Out[24]:
tensor([[[1., 1.],[1., 1.],[1., 1.]],[[1., 1.],[1., 1.],[1., 1.]]])In [25]:
4. Numpy 廣播法則
- 讓所有輸入數組都向其中
shape最長的數組看齊,shape中不足的部分可通過在前面加 1 補齊; - 兩個數組要么在某一維度的長度一致,要么其中一個為 1,否則不能計算;
總結
以上是生活随笔為你收集整理的PyTorch 笔记(10)— Tensor 与 NumPy 相互转换、两种共享内存以及两者的广播法则的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: PyTorch 笔记(09)— Tens
- 下一篇: PyTorch 笔记(11)— Tens