日韩性视频-久久久蜜桃-www中文字幕-在线中文字幕av-亚洲欧美一区二区三区四区-撸久久-香蕉视频一区-久久无码精品丰满人妻-国产高潮av-激情福利社-日韩av网址大全-国产精品久久999-日本五十路在线-性欧美在线-久久99精品波多结衣一区-男女午夜免费视频-黑人极品ⅴideos精品欧美棵-人人妻人人澡人人爽精品欧美一区-日韩一区在线看-欧美a级在线免费观看

歡迎訪問 生活随笔!

生活随笔

當(dāng)前位置: 首頁 > 编程资源 > 编程问答 >内容正文

编程问答

pytorch 使用DataParallel 单机多卡和单卡保存和加载模型时遇到的问题

發(fā)布時(shí)間:2024/8/23 编程问答 35 豆豆
生活随笔 收集整理的這篇文章主要介紹了 pytorch 使用DataParallel 单机多卡和单卡保存和加载模型时遇到的问题 小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.

首先很多網(wǎng)上的博客,講的都不對(duì),自己跟著他們踩了很多坑

1.單卡訓(xùn)練,單卡加載

這里我為了把三個(gè)模塊save到同一個(gè)文件里,我選擇對(duì)所有的模型先封裝成一個(gè)checkpoint字典,然后保存到同一個(gè)文件里,這樣就可以在加載時(shí)只需要加載一個(gè)參數(shù)文件。
保存:

states = {'state_dict_encoder': encoder.state_dict(),'state_dict_decoder': decoder.state_dict(),} torch.save(states, fname)

加載:

#先初始化模型,因?yàn)楸4鏁r(shí)只保存了模型參數(shù),沒有保存模型整個(gè)結(jié)構(gòu) encoder = Encoder() decoder = Decoder() #然后加載參數(shù) checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置 encoder_state_dict=checkpoint['state_dict_encoder'] decoder_state_dict=checkpoint['state_dict_decoder'] encoder.load_state_dict(encoder_state_dict) decoder.load_state_dict(decoder_state_dict)

2.單卡訓(xùn)練,多卡加載

保存:
保存過程一樣,不做任何改變

states = {'state_dict_encoder': encoder.state_dict(),'state_dict_decoder': decoder.state_dict(),} torch.save(states, fname)

加載:
加載過程也沒有任何改變,但是要注意,先加載模型參數(shù),再對(duì)模型做并行化處理

#先初始化模型,因?yàn)楸4鏁r(shí)只保存了模型參數(shù),沒有保存模型整個(gè)結(jié)構(gòu) encoder = Encoder() decoder = Decoder() #然后加載參數(shù) checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置 encoder_state_dict=checkpoint['state_dict_encoder'] decoder_state_dict=checkpoint['state_dict_decoder'] encoder.load_state_dict(encoder_state_dict) decoder.load_state_dict(decoder_state_dict) #并行處理模型 encoder = nn.DataParallel(encoder) decoder = nn.DataParallel(decoder)

3.多卡訓(xùn)練,單卡加載

注意,如果你考慮到以后可能需要單卡加載你多卡訓(xùn)練的模型,建議在保存模型時(shí),去除模型參數(shù)字典里面的module,如何去除呢,使用model.module.state_dict()代替model.state_dict()
保存:

states = {'state_dict_encoder': encoder.module.state_dict(), #不是encoder.state_dict()'state_dict_decoder': decoder.module.state_dict(),} torch.save(states, fname)

加載:
要注意由于我們保存的方式是以單卡的方式保存的,所以還是要先加載模型參數(shù),再對(duì)模型做并行化處理

#先初始化模型,因?yàn)楸4鏁r(shí)只保存了模型參數(shù),沒有保存模型整個(gè)結(jié)構(gòu) encoder = Encoder() decoder = Decoder() #然后加載參數(shù) checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置 encoder_state_dict=checkpoint['state_dict_encoder'] decoder_state_dict=checkpoint['state_dict_decoder'] encoder.load_state_dict(encoder_state_dict) decoder.load_state_dict(decoder_state_dict) #并行處理模型 encoder = nn.DataParallel(encoder) decoder = nn.DataParallel(decoder)

同時(shí),你也可以用第二種方式去保存和加載:

3.多卡訓(xùn)練,單卡加載,方法二

使用model.state_dict()保存,但是單卡加載的時(shí)候,要把模型做并行化(在單卡上并行)
保存:

states = {'state_dict_encoder': encoder.state_dict(), 'state_dict_decoder': decoder.state_dict(),} torch.save(states, fname)

加載:
要注意由于我們保存的方式是以多卡的方式保存的,所以無論你加載之后的模型是在單卡運(yùn)行還是在多卡運(yùn)行,都先把模型并行化再去加載

#先初始化模型,因?yàn)楸4鏁r(shí)只保存了模型參數(shù),沒有保存模型整個(gè)結(jié)構(gòu) encoder = Encoder() decoder = Decoder() #并行處理模型 encoder = nn.DataParallel(encoder) decoder = nn.DataParallel(decoder) #然后加載參數(shù) checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置 encoder_state_dict=checkpoint['state_dict_encoder'] decoder_state_dict=checkpoint['state_dict_decoder'] encoder.load_state_dict(encoder_state_dict) decoder.load_state_dict(decoder_state_dict)

4.多卡保存,多卡加載

這就和多卡保存,單卡加載第二中方式一樣了
**使用model.state_dict()**保存,加載的時(shí)候,要先把模型做并行化(在多卡上并行)
保存:

states = {'state_dict_encoder': encoder.state_dict(), 'state_dict_decoder': decoder.state_dict(),} torch.save(states, fname)

加載:
要注意由于我們保存的方式是以多卡的方式保存的,所以無論你加載之后的模型是在單卡運(yùn)行還是在多卡運(yùn)行,都先把模型并行化再去加載

#先初始化模型,因?yàn)楸4鏁r(shí)只保存了模型參數(shù),沒有保存模型整個(gè)結(jié)構(gòu) encoder = Encoder() decoder = Decoder() #并行處理模型 encoder = nn.DataParallel(encoder) decoder = nn.DataParallel(decoder) #然后加載參數(shù) checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置 encoder_state_dict=checkpoint['state_dict_encoder'] decoder_state_dict=checkpoint['state_dict_decoder'] encoder.load_state_dict(encoder_state_dict) decoder.load_state_dict(decoder_state_dict) 創(chuàng)作挑戰(zhàn)賽新人創(chuàng)作獎(jiǎng)勵(lì)來咯,堅(jiān)持創(chuàng)作打卡瓜分現(xiàn)金大獎(jiǎng)

總結(jié)

以上是生活随笔為你收集整理的pytorch 使用DataParallel 单机多卡和单卡保存和加载模型时遇到的问题的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。

如果覺得生活随笔網(wǎng)站內(nèi)容還不錯(cuò),歡迎將生活随笔推薦給好友。