PyTorch模型读写、参数初始化、Finetune
使用了一段時(shí)間PyTorch,感覺愛不釋手(0-0),聽說現(xiàn)在已經(jīng)有C++接口。在應(yīng)用過程中不可避免需要使用Finetune/參數(shù)初始化/模型加載等。
模型保存/加載
1.所有模型參數(shù)
訓(xùn)練過程中,有時(shí)候會(huì)由于各種原因停止訓(xùn)練,這時(shí)候我們訓(xùn)練過程中就需要注意將每一輪epoch的模型保存(一般保存最好模型與當(dāng)前輪模型)。一般使用pytorch里面推薦的保存方法。該方法保存的是模型的參數(shù)。
#保存模型到checkpoint.pth.tar torch.save(model.module.state_dict(), ‘checkpoint.pth.tar’)對(duì)應(yīng)的加載模型方法為(這種方法需要先反序列化模型獲取參數(shù)字典,因此必須先load模型,再load_state_dict):
mymodel.load_state_dict(torch.load(‘checkpoint.pth.tar’))有了上面的保存后,現(xiàn)以一個(gè)例子說明如何在inference AND/OR resume train使用。
#保存模型的狀態(tài),可以設(shè)置一些參數(shù),后續(xù)可以使用 state = {'epoch': epoch + 1,#保存的當(dāng)前輪數(shù)'state_dict': mymodel.state_dict(),#訓(xùn)練好的參數(shù)'optimizer': optimizer.state_dict(),#優(yōu)化器參數(shù),為了后續(xù)的resume'best_pred': best_pred#當(dāng)前最好的精度,....,...}#保存模型到checkpoint.pth.tar torch.save(state, ‘checkpoint.pth.tar’) #如果是best,則復(fù)制過去 if is_best:shutil.copyfile(filename, directory + 'model_best.pth.tar')checkpoint = torch.load('model_best.pth.tar') model.load_state_dict(checkpoint['state_dict'])#模型參數(shù) optimizer.load_state_dict(checkpoint['optimizer'])#優(yōu)化參數(shù) epoch = checkpoint['epoch']#epoch,可以用于更新學(xué)習(xí)率等#有了以上的東西,就可以繼續(xù)重新訓(xùn)練了,也就不需要擔(dān)心停止程序重新訓(xùn)練。 train/eval .... ....上面是pytorch建議使用的方法,當(dāng)然還有第二種方法。這種方法靈活性不高,不推薦。
#保存 torch.save(mymodel,‘checkpoint.pth.tar’)#加載 mymodel = torch.load(‘checkpoint.pth.tar’)2.部分模型參數(shù)
在很多時(shí)候,我們加載的是已經(jīng)訓(xùn)練好的模型,而訓(xùn)練好的模型可能與我們定義的模型不完全一樣,而我們只想使用一樣的那些層的參數(shù)。
有幾種解決方法:
(1)直接在訓(xùn)練好的模型開始搭建自己的模型,就是先加載訓(xùn)練好的模型,然后再它基礎(chǔ)上定義自己的模型;
model_ft = models.resnet18(pretrained=use_pretrained) self.conv1 = model_ft.conv1 self.bn = model_ft.bn ... ...(2) 自己定義好模型,直接加載模型
#第一種方法: mymodelB = TheModelBClass(*args, **kwargs) # strict=False,設(shè)置為false,只保留鍵值相同的參數(shù) mymodelB.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)#第二種方法: # 加載模型 model_pretrained = models.resnet18(pretrained=use_pretrained)# mymodel's state_dict, # 如: conv1.weight # conv1.bias mymodelB_dict = mymodelB.state_dict()# 將model_pretrained的建與自定義模型的建進(jìn)行比較,剔除不同的 pretrained_dict = {k: v for k, v in model_pretrained.items() if k in mymodelB_dict} # 更新現(xiàn)有的model_dict mymodelB_dict.update(pretrained_dict)# 加載我們真正需要的state_dict mymodelB.load_state_dict(mymodelB_dict)# 方法2可能更直觀一些參數(shù)初始化
第二個(gè)問題是參數(shù)初始化問題,在很多代碼里面都會(huì)使用到,畢竟不是所有的都是有預(yù)訓(xùn)練參數(shù)。這時(shí)就需要對(duì)不是與預(yù)訓(xùn)練參數(shù)進(jìn)行初始化。pytorch里面的每個(gè)Tensor其實(shí)是對(duì)Variabl的封裝,其包含data、grad等接口,因此可以用這些接口直接賦值。這里也提供了怎樣把其他框架(caffe/tensorflow/mxnet/gluonCV等)訓(xùn)練好的模型參數(shù)直接賦值給pytorch.其實(shí)就是對(duì)data直接賦值。
pytorch提供了初始化參數(shù)的方法:
def weight_init(m):if isinstance(m,nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0,math.sqrt(2./n))elif isinstance(m,nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()但一般如果沒有很大需求初始化參數(shù),也沒有問題(不確定性能是否有影響的情況下),pytorch內(nèi)部是有默認(rèn)初始化參數(shù)的。
Fintune
最后就是精調(diào)了,我們平時(shí)做實(shí)驗(yàn),至少backbone是用預(yù)訓(xùn)練的模型,將其用作特征提取器,或者在它上面做精調(diào)。
用于特征提取的時(shí)候,要求特征提取部分參數(shù)不進(jìn)行學(xué)習(xí),而pytorch提供了requires_grad參數(shù)用于確定是否進(jìn)去梯度計(jì)算,也即是否更新參數(shù)。以下以minist為例,用resnet18作特征提取:
#加載預(yù)訓(xùn)練模型 model = torchvision.models.resnet18(pretrained=True)#遍歷每一個(gè)參數(shù),將其設(shè)置為不更新參數(shù),即不學(xué)習(xí) for param in model.parameters():param.requires_grad = False# 將全連接層改為mnist所需的10類,注意:這樣更改后requires_grad默認(rèn)為True model.fc = nn.Linear(512, 10)# 優(yōu)化 optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)用于全局精調(diào)時(shí),我們一般對(duì)不同的層需要設(shè)置不同的學(xué)習(xí)率,預(yù)訓(xùn)練的層學(xué)習(xí)率小一點(diǎn),其他層大一點(diǎn)。這要怎么做呢?
# 加載預(yù)訓(xùn)練模型 model = torchvision.models.resnet18(pretrained=True) model.fc = nn.Linear(512, 10)# 參考:https://blog.csdn.net/u012759136/article/details/65634477 ignored_params = list(map(id, model.fc.parameters())) base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())# 對(duì)不同參數(shù)設(shè)置不同的學(xué)習(xí)率 params_list = [{'params': base_params, 'lr': 0.001},] params_list.append({'params': model.fc.parameters(), 'lr': 0.01})optimizer = torch.optim.SGD(params_list,0.001,momentum=args.momentum,weight_decay=args.weight_decay)最后整理一下目前,pytorch預(yù)訓(xùn)練的基礎(chǔ)模型:
(1)torchvision
torchvision里面已經(jīng)提供了不同的預(yù)訓(xùn)練模型,一般也夠用了。
pytorch/visiongithub.com正在上傳…重新上傳取消?
包含了alexnet/densenet各種版本(densenet121/densenet169/densenet201/densenet161)/inception_v3/resnet各種版本(resnet18', 'resnet34', 'resnet50', 'resnet101','resnet152')/SqueezeNet各種版本( 'squeezenet1_0', 'squeezenet1_1')/VGG各種版本( 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn','vgg19_bn', 'vgg19')
(2)其他預(yù)訓(xùn)練好的模型,如,SENet/NASNet等。
Cadene/pretrained-models.pytorchgithub.com
(3)gluonCV轉(zhuǎn)pytorch的模型,包括,分類網(wǎng)絡(luò),分割網(wǎng)絡(luò)等,這里的精度均比其他框架高幾個(gè)百分點(diǎn)。
zhanghang1989/gluoncv-torchgithub.com正在上傳…重新上傳取消?
總結(jié)
以上是生活随笔為你收集整理的PyTorch模型读写、参数初始化、Finetune的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: [Pytorch].pth转.pt文件
- 下一篇: [图解]小白都能看懂的FASTER R-