nn.Dataparallel pytorch 平行计算的两种方法
生活随笔
收集整理的這篇文章主要介紹了
nn.Dataparallel pytorch 平行计算的两种方法
小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
1. nn.Dataparallel
多GPU加速訓(xùn)練
原理:
模型分別復(fù)制到每個(gè)卡中,然后把輸入切片,分別放入每個(gè)卡中計(jì)算,然后再用第一塊卡進(jìn)行匯總求loss,反向傳播更新參數(shù)。
第一塊卡占用的內(nèi)存多一點(diǎn),因?yàn)閛utput loss每次都會(huì)在第一塊GPU相加計(jì)算,這就造成了第一塊GPU的負(fù)載遠(yuǎn)遠(yuǎn)大于剩余其他的顯卡。
要求:
batch_size > GPU 數(shù)量
第一種方法:
os.environment['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' device_ids = [0,1,2,3] net = torch.nn. Dataparallel(net, device_ids =device_ids) net = net.cuda()第二種方法
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2" if torch.cuda.is_available():self.device = "cuda"if torch.cuda.device_count() > 1:self.G = nn.DataParallel(self.G)self.D_A = nn.DataParallel(self.D_A)self.D_B = nn.DataParallel(self.D_B)self.vgg = nn.DataParallel(self.vgg)self.criterionHis = nn.DataParallel(self.criterionHis)self.criterionGAN = nn.DataParallel(self.criterionGAN)self.criterionL1 = nn.DataParallel(self.criterionL1)self.criterionL2 = nn.DataParallel(self.criterionL2)self.criterionGAN = nn.DataParallel(self.criterionGAN)self.G.cuda()self.vgg.cuda()self.criterionHis.cuda()self.criterionGAN.cuda()self.criterionL1.cuda()self.criterionL2.cuda()self.D_A.cuda()self.D_B.cuda()2.模型分別單獨(dú)放入每個(gè)指定的GPU中
把模型分別放到指定的GPU中,然后在運(yùn)算的過程中,需要把利用**.to(cuda:x)** 去轉(zhuǎn)移數(shù)據(jù)。這樣暫用的內(nèi)存比平行計(jì)算小。但是配置復(fù)雜一點(diǎn)。
vgg_encoder = VGGEncoder().to('cuda:0')attn=CoAttention(channel=512).to('cuda:1')decoder = Decoder().to('cuda:2')optimizer_decoder = Adam(decoder.parameters(), lr=args.learning_rate)optimizer_attn = Adam(attn.parameters(), lr=args.learning_rate)content = content.cuda() # 默認(rèn)的是cuda:0style = style.cuda()content_features = vgg_encoder(content, output_last_feature=True)style_features = vgg_encoder(style, output_last_feature=True)content_features, style_features=attn(content_features.to('cuda:1'),style_features.to('cuda:1')) # 因?yàn)閍ttn在cuda:1中總結(jié)
以上是生活随笔為你收集整理的nn.Dataparallel pytorch 平行计算的两种方法的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python缩进教学_Python缩进和
- 下一篇: Pytorch RuntimeERROR