pytorch的两个函数 tensor.detach(),tensor.detach_(),tensor.clone() 的作用和区别
前言:當我們在訓練網絡的時候可能希望保持一部分的網絡參數不變,只對其中一部分的參數進行調整;或者值訓練部分分支網絡,并不讓其梯度對主網絡的梯度造成影響,這時候我們就需要使用detach()函數來切斷一些分支的反向傳播。
1 tensor.detach()
返回一個新的tensor,從當前計算圖中分離下來的,但是仍指向原變量的存放位置,不同之處只是requires_grad為false,得到的這個tensor永遠不需要計算其梯度,不具有grad。
即使之后重新將它的requires_grad置為true,它也不會具有梯度grad
這樣我們就會繼續使用這個新的tensor進行計算,后面當我們進行反向傳播時,到該調用detach()的tensor就會停止,不能再繼續向前進行傳播
注意:使用detach返回的tensor和原始的tensor共同一個內存,即一個修改另一個也會跟著改變。
比如正常的例子是:
import torcha = torch.tensor([1, 2, 3.], requires_grad=True) print(a.grad) # None out = a.sigmoid() out.sum().backward() # tensor([0.1966, 0.1050, 0.0452]) print(a.grad)1.1 當使用detach()分離tensor但是沒有更改這個tensor時,并不會影響backward():
import torcha = torch.tensor([1, 2, 3.], requires_grad=True) print(a.grad) # None out = a.sigmoid() print(out) # tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)# 添加detach(),c的requires_grad為False c = out.detach() print(c) # tensor([0.7311, 0.8808, 0.9526])# 這個時候沒有對c進行更改,所以不會影響backward() out.sum().backward() print(a.grad) # tensor([0.1966, 0.1050, 0.0452])從上可見tensor c是由out分離得到的,但是我也沒有去改變這個c,這個時候依然對原來的out求導是不會有錯誤的,即
c,out之間的區別是c是沒有梯度的,out是有梯度的,但是需要注意的是下面兩種情況是匯報錯的
1.2 當使用detach()分離tensor,然后用這個分離出來的tensor去求導數,會影響backward(),會出現錯誤
import torcha = torch.tensor([1, 2, 3.], requires_grad=True) print(a.grad) # None out = a.sigmoid() print(out) # tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)# 添加detach(),c的requires_grad為False c = out.detach() print(c) # tensor([0.7311, 0.8808, 0.9526])# 使用新生成的Variable進行反向傳播 c.sum().backward() print(a.grad)''' Traceback (most recent call last):File "/opt/data/private/lhl/SSL/uncertainty_maps/main_detach.py", line 13, in <module>c.sum().backward()File "/root/anaconda3/envs/py379/lib/python3.7/site-packages/torch/tensor.py", line 221, in backwardtorch.autograd.backward(self, gradient, retain_graph, create_graph)File "/root/anaconda3/envs/py379/lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backwardallow_unreachable=True) # allow_unreachable flag RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn '''1.3 當使用detach()分離tensor并且更改這個tensor時,即使再對原來的out求導數,會影響backward(),會出現錯誤
如果此時對c進行了更改,這個更改會被autograd追蹤,在對out.sum()進行backward()時也會報錯,因為此時的值進行backward()得到的梯度是錯誤的:
import torcha = torch.tensor([1, 2, 3.], requires_grad=True) print(a.grad) # None out = a.sigmoid() print(out) # tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)# 添加detach(),c的requires_grad為False c = out.detach() print(c) # tensor([0.7311, 0.8808, 0.9526]) c.zero_() # 使用inplace函數對其進行修改# 這個時候會發現對c進行更改,會影響backward c.sum().backward() print(a.grad)''' Traceback (most recent call last):File "/opt/data/private/lhl/SSL/uncertainty_maps/main_detach.py", line 14, in <module>c.sum().backward()File "/root/anaconda3/envs/py379/lib/python3.7/site-packages/torch/tensor.py", line 221, in backwardtorch.autograd.backward(self, gradient, retain_graph, create_graph)File "/root/anaconda3/envs/py379/lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backwardallow_unreachable=True) # allow_unreachable flag RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn '''2 tensor.detach_()
將一個tensor從創建它的圖中分離,并把它設置成葉子tensor
其實就相當于變量之間的關系本來是x -> m -> y,這里的葉子tensor是x,但是這個時候對m進行了m.detach_()操作,其實就是進行了兩個操作:
- 將m的grad_fn的值設置為None,這樣m就不會再與前一個節點x關聯,這里的關系就會變成x, m -> y,此時的m就變成了葉子結點
- 然后會將m的requires_grad設置為False,這樣對y進行backward()時就不會求m的梯度
總結:其實detach()和detach_()很像,兩個的區別就是detach_()是對本身的更改,detach()則是生成了一個新的tensor
比如x -> m -> y中如果對m進行detach(),后面如果反悔想還是對原來的計算圖進行操作還是可以的
但是如果是進行了detach_(),那么原來的計算圖也發生了變化,就不能反悔了
3 tensor.clone()
clone(memory_format=torch.preserve_format)→ Tensor
返回tensor的拷貝,返回的新tensor和原來的tensor具有同樣的大小和數據類型。
原tensor的requires_grad=True
clone()返回的tensor是中間節點,梯度會流向原tensor,即返回的tensor的梯度會疊加在原tensor上
參考資料
總結
以上是生活随笔為你收集整理的pytorch的两个函数 tensor.detach(),tensor.detach_(),tensor.clone() 的作用和区别的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: qq群的表设计探究
- 下一篇: javaboot+es