额外参数_Pytorch获取模型参数情况的方法
分享人工智能技術干貨,專注深度學習與計算機視覺領域!
相較于Tensorflow,Pytorch一開始就是以動態圖構建神經網絡圖的,其獲取模型參數的方法也比較容易,既可以根據其內建接口自己寫代碼獲取模型參數情況,也可以借助第三方庫來獲取模型參數情況,下面,就讓我們一起來了解Pytorch獲取模型參數情況的這兩種方法!
Pytorch依據其內建接口自己寫代碼獲取模型參數情況,我們主要是借助該框架提供的模型parameters()接口并獲取對應參數的size來實現的,對于該參數是否屬于可訓練參數,那么我們可以依據Pytorch提供的requires_grad標志位來進行判斷,具體方法如下代碼所示:
# 定義總參數量、可訓練參數量及非可訓練參數量變量 Total_params = 0 Trainable_params = 0 NonTrainable_params = 0# 遍歷model.parameters()返回的全局參數列表 for param in model.parameters():mulValue = np.prod(param.size()) # 使用numpy prod接口計算參數數組所有元素之積Total_params += mulValue # 總參數量if param.requires_grad:Trainable_params += mulValue # 可訓練參數量else:NonTrainable_params += mulValue # 非可訓練參數量print(f'Total params: {Total_params}') print(f'Trainable params: {Trainable_params}') print(f'Non-trainable params: {NonTrainable_params}')如無特殊設定,一般來說,因為我們是直接獲取的model網絡參數,因此很少有不可訓練參數,往往NonTrainable_params輸出結果是0。
這里的第三方庫是指torchsummary,欲要使用該庫,首先我們得安裝它,命令如下:
pip install torchsummary然后,引入該庫的summary方法:
from torchsummary import summary最后,直接調用一條命令即可獲取到Pytorch模型參數情況:
summary(model, input_size=(ch, h, w), batch_size=-1)這里的ch是指輸入張量的channel數量,h表示輸入張量的高,w表示輸入張量的寬。
我們從以上代碼可以看到,借助第三方庫torchsummary來獲取Pytorch的模型參數情況非常之簡便,只需確認好輸入圖像shape即可,那么,torchsummary的輸出是如何的呢?
上圖是應用torchsummary獲得輸出結果的一個示例,這與Tensorflow V2.x及其之后的版本的模型summary()輸出是差不多的,輸出信息里也是有各個類別的參數量情況、每層網絡的參數量、額外的層名稱及其輸出shape大小,此外,torchsummary庫還為我們計算了輸入大小、模型參數大小及前向/反向傳播參數量大小,可謂信息非常細致,這極大地方便了我們查看Pytorch模型的構造情況。
除了上述兩種獲取Pytorch模型參數情況的方法,我們當然也可以直接使用model.state_dict()接口獲取Pytorch網絡參數,但是此種方法打印出來的信息結構非常混亂,也沒有為我們進行有效的信息整理,因此很不建議該方法。
總結
以上是生活随笔為你收集整理的额外参数_Pytorch获取模型参数情况的方法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: java 使用.aar_java -
- 下一篇: java 配置dbcp_java –