pytorch 矩阵相乘_深度学习 — — PyTorch入门(三)
autograd和動態計算圖可以說是pytorch中非常核心的部分,我們在之前的文章中提到:autograd其實就是反向求偏導的過程,而在求偏導的過程中,鏈式求導法則和雅克比矩陣是其實現的數學基礎;Tensor構成的動態計算圖是使用pytorch的實現的結構。
backward()函數
backward()是通過將參數(默認為1x1單位張量)通過反向圖追蹤所有對于該張量的操作,使用鏈式求導法則從根張量追溯到每個葉子節點以計算梯度。下圖描述了pytorch對于函數z = (a + b)(b - c)構建的計算圖,以及從根節點z到葉子節點a,b,c的求導過程:
注意:計算圖已經在前向傳遞過程中已經被動態創建了,反向傳播僅使用已存在的計算圖計算梯度并將其存儲在葉子節點中。
為了節約內存,在每一輪迭代完成后,計算圖就會被釋放,若需要多次調用backward()方法,則需要在使用時添加retain_graph=True,否則會報如下錯誤:
RuntimeError:?Trying to backward through the graph a second time, but the buffers have already been freed.
若我們在使用過程中,僅僅想求得某個節點的梯度,而非整個圖的梯度,則需要用到Tensor的.grad屬性,如下列代碼所示:
import torch# 創建計算圖x = torch.tensor(1.0, requires_grad = True)z = x ** 3# 計算梯度z.backward() print(x.grad.data)需要注意的是:當調用z.backward()時,將自動計算z.backward(torch.tensor(1.0)),其中 torch.tensor(1.0)是用于終止連式法則梯度乘法的外部梯度。可以將此作為輸入傳遞給MulBackward函數,以進一步計算x的梯度。
在上述的示例中,我們給出了標量對向量的求導過程,那么當向量對向量進行求導時呢?例如,需要計算梯度的張量x和y如下:
x = torch.tensor([0.0, 2.0, 8.0], requires_grad = True)y = torch.tensor([5.0 , 1.0 , 7.0], requires_grad = True)z = x * y此時調用z.backward()函數將會報如下錯誤:
RuntimeError: grad can be implicitly created only for scalar outputs
錯誤提示我們只能應用于標量輸出。若我們想對向量z進行梯度計算,先了解一下Jacobian矩陣。
Jacobian矩陣和向量
從數學角度上來講:雅克比矩陣是基于函數對所有變量一階偏導數的數值矩陣,當輸入個數等于輸出個數時又稱為雅克比行列式。
而autograd類在實際運用的過程中也是通過計算雅克比向量積實現對向量梯度的計算。簡單來說,雅可比矩陣是代表兩個向量的所有可能偏導數的矩陣,可以用于求一個向量相對于另一個向量的梯度。
注:在此過程中,PyTorch不會顯式構造整個Jacobian矩陣,而是直接計算Jacobian矢量積,這種計算方式更為簡便。
如果向量X = [x1,x2,… xn]通過函數f(X)= [f1,f2,… fn]計算其他向量,假設f對于x的每個一階偏導數都存在,則f(X)相對于X的梯度矩陣為:
假設待計算梯度的張量X為:X = [x1,x2,… xn](機器學習模型的權重),X可以進行一些運算以形成向量Y:Y = f(X)= [y1,y2,… ym]。然后,使用Y來計算標量損失l。假設向量v恰好是標量損失l相對于向量Y的梯度,則:
此時,向量v則被稱為grad_tensor,即梯度張量。并將其作為參數傳遞給backward()函數。為了獲得損失l相對于權重X的梯度,將Jacobian矩陣J與向量v相乘,得到最終梯度:
綜上所述,pytorch在使用計算圖求導的過程中整體可以分為以下兩種情況:
1. 若標量對向量求導,則可以直接調用backward()函數;
2. 若向量A對向量B求導,則先求得向量A對于向量B的Jacobian矩陣,并將其與grad_tensors對應的矩陣進行點乘計算得到最終梯度。
·? END? ·
RECOMMEND推薦閱讀?1. 效率提升的軟件大禮包
?2.?深度學習——入門PyTorch(一)
?3.?深度學習——入門PyTorch(二)
?4. PyTorch入門——autograd(一)
?5. PyTorch入門——autograd(二)
總結
以上是生活随笔為你收集整理的pytorch 矩阵相乘_深度学习 — — PyTorch入门(三)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 哪种语言 连接 oracle,Go语言连
- 下一篇: 计算机视觉中的多视图几何_基于深度学习的