pytorch模型加载测试_pytorch模型加载方法汇总
Pytorch有很多方便易用的包,今天要談的是torchvision包,它包括3個子包,分別是: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分別是預(yù)定義好的數(shù)據(jù)集(比如MNIST、CIFAR10等)、預(yù)定義好的經(jīng)典網(wǎng)絡(luò)結(jié)構(gòu)(比如AlexNet、VGG、ResNet等)和預(yù)定義好的數(shù)據(jù)增強方法(比如Resize、ToTensor等)。這些方法可以直接調(diào)用,簡化我們建模的過程,也可以作為我們學習或構(gòu)建新的模型的參考。
本文,我們講述的是models,且只談模型的加載。models這個包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的網(wǎng)絡(luò)結(jié)構(gòu),并且提供了預(yù)訓(xùn)練模型,可以通過簡單調(diào)用來讀取網(wǎng)絡(luò)結(jié)構(gòu)和預(yù)訓(xùn)練模型。
模型地址:https://github.com/pytorch/vision/tree/master/torchvision/models
官方文檔:https://pytorch.org/docs/master/torchvision/models.html
我將加載的方法簡單總結(jié)為以下四種:
1.直接加載預(yù)訓(xùn)練模型
1 importtorchvision.models as models2
3 resnet50 = models.resnet50(pretrained=True)
這樣就導(dǎo)入了resnet50的預(yù)訓(xùn)練模型了。
如果只需要網(wǎng)絡(luò)結(jié)構(gòu),不需要用預(yù)訓(xùn)練模型的參數(shù)來初始化,那么就是:
model =torchvision.models.resnet50(pretrained=False)
或者把resnet復(fù)制到自己的目錄下,新建個model文件夾
可以參考下面的貓狗大戰(zhàn)入門算法入門
https://github.com/JackwithWilshere/Kaggle-Dogs_vs_Cats_PyTorch
2.修改某一層
以resnet為例,默認的是ImageNet的1000類,比如我們要做二分類,分類貓和狗
1 resnet.fc = nn.Linear(2048, 2) #resnet 第一層卷積的卷積核是7,我們可能想改成5,那么可以通過以下方法修改:
2
3 #未經(jīng)試驗,修改需要有理論依據(jù),計算featuremap維度使之匹配。
4 resnet.conv1 = nn.Conv2d(3, 64,kernel_size=5, stride=2, padding=3, bias=False)
3.加載部分預(yù)訓(xùn)練模型
對于具體的任務(wù),很難保證模型和公開的模型完全一樣,但是預(yù)訓(xùn)練模型的參數(shù)確實有助于提高訓(xùn)練的準確率,為了結(jié)合二者的優(yōu)點,就需要我們加載部分預(yù)訓(xùn)練模型。
1 #加載model,model是自己定義好的模型
resnet50 = models.resnet50(pretrained=True)
pretrained_dict =resnet50.state_dict()
model =Net(...)4 5 #讀取參數(shù) 6 model_dict =model.state_dict()
9 #將pretrained_dict里不屬于model_dict的鍵剔除掉
10 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k inmodel_dict} #更新現(xiàn)有的model_dict 13 model_dict.update(pretrained_dict) #這一塊更新的什么?? #加載我們真正需要的state_dict 16 model.load_state_dict(model_dict)
4. 加載自己的模型
其實這個是保存和恢復(fù)模型,比如我們訓(xùn)練好的模型保存,然后加載用于測試。
方法一(推薦):
第一種方法也是官方推薦的方法,只保存和恢復(fù)模型中的參數(shù)(權(quán)重數(shù)值)。
使用這種方法,我們需要自己導(dǎo)入模型的結(jié)構(gòu)信息。
(1)保存
1 torch.save(model.state_dict(), PATH)2
3 #example
4 torch.save(resnet50.state_dict(),'ckp/model.pth')
(2)恢復(fù)
1 model = ModelClass(*args, **kwargs)2 model.load_state_dict(torch.load(PATH))3
4 #example
5 resnet=resnet50(pretrained=True)6 resnet.load_state_dict(torch.load('ckp/model.pth'))
方法二:
使用這種方法,將會同時保存模型的參數(shù)和結(jié)構(gòu)信息到模型文件中。
(1)保存
torch.save (the_model, PATH)
(2)恢復(fù)
torch.load (the_model, PATH)
創(chuàng)作挑戰(zhàn)賽新人創(chuàng)作獎勵來咯,堅持創(chuàng)作打卡瓜分現(xiàn)金大獎總結(jié)
以上是生活随笔為你收集整理的pytorch模型加载测试_pytorch模型加载方法汇总的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: mysql临时表空间_MySQL 5.7
- 下一篇: 查看本机所有请求_【松勤教育】Fiddl