當(dāng)前位置:
首頁 >
【Pytorch神经网络实战案例】16 条件WGAN模型生成可控Fashon-MNST模拟数据
發(fā)布時(shí)間:2024/7/5
37
豆豆
生活随笔
收集整理的這篇文章主要介紹了
【Pytorch神经网络实战案例】16 条件WGAN模型生成可控Fashon-MNST模拟数据
小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
1 條件GAN前置知識(shí)
條件GAN也可以使GAN所生成的數(shù)據(jù)可控,使模型變得實(shí)用,
1.1 實(shí)驗(yàn)描述
搭建條件GAN模型,實(shí)現(xiàn)向模型中輸入標(biāo)簽,并使其生成與標(biāo)簽類別對(duì)應(yīng)的模擬數(shù)據(jù)的功能,基于WGAN-gp模型改造實(shí)現(xiàn)帶有條件的wGAN-gp模型。
2?實(shí)例代碼編寫
條件GAN與條件自編碼神經(jīng)網(wǎng)絡(luò)的做法幾乎一樣,在GAN的基礎(chǔ)之上,為每個(gè)模型輸入都添加一個(gè)標(biāo)簽向量。
2.1 代碼實(shí)戰(zhàn):引入模塊并載入樣本----WGAN_cond_237.py(第1部分)
import torch import torchvision from torchvision import transforms from torch.utils.data import DataLoader from torch import nn import torch.autograd as autograd import matplotlib.pyplot as plt import numpy as np import matplotlib import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"# 1.1 引入模塊并載入樣本:定義基本函數(shù),加載FashionMNIST數(shù)據(jù)集 def to_img(x):x = 0.5 * (x+1)x = x.clamp(0,1)x = x.view(x.size(0),1,28,28)return xdef imshow(img,filename = None):npimg = img.numpy()plt.axis('off')array = np.transpose(npimg,(1,2,0))if filename != None:matplotlib.image.imsave(filename,array)else:plt.imshow(array)# plt.savefig(filename) # 保存圖片 注釋掉,因?yàn)闀?huì)報(bào)錯(cuò),暫時(shí)不知道什么原因 2022.3.26 15:20plt.show()img_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],std=[0.5])] )data_dir = './fashion_mnist'train_dataset = torchvision.datasets.FashionMNIST(data_dir,train=True,transform=img_transform,download=True) train_loader = DataLoader(train_dataset,batch_size=1024,shuffle=True) # 測(cè)試數(shù)據(jù)集 val_dataset = torchvision.datasets.FashionMNIST(data_dir,train=False,transform=img_transform) test_loader = DataLoader(val_dataset,batch_size=10,shuffle=False) # 指定設(shè)備 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(device)2.2 代碼實(shí)戰(zhàn):實(shí)現(xiàn)生成器和判別器----WGAN_cond_237.py(第2部分)
# 1.2 實(shí)現(xiàn)生成器和判別器 :因?yàn)閺?fù)雜部分都放在loss值的計(jì)算方面了,所以生成器和判別器就會(huì)簡單一些。 # 生成器和判別器各自有兩個(gè)卷積和兩個(gè)全連接層。生成器最終輸出與輸入圖片相同維度的數(shù)據(jù)作為模擬樣本。 # 判別器的輸出不需要有激活函數(shù),并且輸出維度為1的數(shù)值用來表示結(jié)果。 # 在GAN模型中,因判別器的輸入則是具體的樣本數(shù)據(jù),要區(qū)分每個(gè)數(shù)據(jù)的分布特征,所以判別器使用實(shí)例歸一化, class WGAN_D(nn.Module): # 定義判別器類D :有兩個(gè)卷積和兩個(gè)全連接層def __init__(self,inputch=1):super(WGAN_D, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(inputch,64,4,2,1), # 輸出形狀為[batch,64,28,28]nn.LeakyReLU(0.2,True),nn.InstanceNorm2d(64,affine=True))self.conv2 = nn.Sequential(nn.Conv2d(64,128,4,2,1),# 輸出形狀為[batch,64,14,14]nn.LeakyReLU(0.2,True),nn.InstanceNorm2d(128,affine=True))self.fc = nn.Sequential(nn.Linear(128*7*7,1024),nn.LeakyReLU(0.2,True))self.fc2 = nn.Sequential(nn.InstanceNorm1d(1,affine=True),nn.Flatten(),nn.Linear(1024,1))def forward(self,x,*arg): # 正向傳播x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0),-1)x = self.fc(x)x = x.reshape(x.size(0),1,-1)x = self.fc2(x)return x.view(-1,1).squeeze(1)# 在GAN模型中,因生成器的初始輸入是隨機(jī)值,所以生成器使用批量歸一化。 class WGAN_G(nn.Module): # 定義生成器類G:有兩個(gè)卷積和兩個(gè)全連接層def __init__(self,input_size,input_n=1):super(WGAN_G, self).__init__()self.fc1 = nn.Sequential(nn.Linear(input_size * input_n,1024),nn.ReLU(True),nn.BatchNorm1d(1024))self.fc2 = nn.Sequential(nn.Linear(1024,7*7*128),nn.ReLU(True),nn.BatchNorm1d(7*7*128))self.upsample1 = nn.Sequential(nn.ConvTranspose2d(128,64,4,2,padding=1,bias=False), # 輸出形狀為[batch,64,14,14]nn.ReLU(True),nn.BatchNorm2d(64))self.upsample2 = nn.Sequential(nn.ConvTranspose2d(64,1,4,2,padding=1,bias=False), # 輸出形狀為[batch,64,28,28]nn.Tanh())def forward(self,x,*arg): # 正向傳播x = self.fc1(x)x = self.fc2(x)x = x.view(x.size(0),128,7,7)x = self.upsample1(x)img = self.upsample2(x)return img2.3?代碼實(shí)戰(zhàn):定義函數(shù)完成梯度懲罰項(xiàng)----WGAN_cond_237.py(第3部分)
# 1.3 定義函數(shù)compute_gradient_penalty()完成梯度懲罰項(xiàng) # 懲罰項(xiàng)的樣本X_inter由一部分Pg分布和一部分Pr分布組成,同時(shí)對(duì)D(X_inter)求梯度,并計(jì)算梯度與1的平方差,最終得到gradient_penalties lambda_gp = 10 # 計(jì)算梯度懲罰項(xiàng) def compute_gradient_penalty(D,real_samples,fake_samples,y_one_hot):# 獲取一個(gè)隨機(jī)數(shù),作為真假樣本的采樣比例eps = torch.FloatTensor(real_samples.size(0),1,1,1).uniform_(0,1).to(device)# 按照eps比例生成真假樣本采樣值X_interX_inter = (eps * real_samples + ((1-eps)*fake_samples)).requires_grad_(True)d_interpolates = D(X_inter,y_one_hot)fake = torch.full((real_samples.size(0),),1,device=device) # 計(jì)算梯度輸出的掩碼,在本例中需要對(duì)所有梯度進(jìn)行計(jì)算,故需要按照樣本個(gè)數(shù)生成全為1的張量。# 求梯度gradients = autograd.grad(outputs=d_interpolates, # 輸出值outputs,傳入計(jì)算過的張量結(jié)果inputs=X_inter,# 待求梯度的輸入值inputs,傳入可導(dǎo)的張量,即requires_grad=Truegrad_outputs=fake, # 傳出梯度的掩碼grad_outputs,使用1和0組成的掩碼,在計(jì)算梯度之后,會(huì)將求導(dǎo)結(jié)果與該掩碼進(jìn)行相乘得到最終結(jié)果。create_graph=True,retain_graph=True,only_inputs=True)[0]gradients = gradients.view(gradients.size(0),-1)gradient_penaltys = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gpreturn gradient_penaltys2.4?代碼實(shí)戰(zhàn):定義模型的訓(xùn)練函數(shù)----WGAN_cond_237.py(第4部分)
# 1.4 定義模型的訓(xùn)練函數(shù) # 定義函數(shù)train(),實(shí)現(xiàn)模型的訓(xùn)練過程。 # 在函數(shù)train()中,按照對(duì)抗神經(jīng)網(wǎng)絡(luò)專題(一)中的式(8-24)實(shí)現(xiàn)模型的損失函數(shù)。 # 判別器的loss為D(fake_samples)-D(real_samples)再加上聯(lián)合分布樣本的梯度懲罰項(xiàng)gradient_penalties,其中fake_samples為生成的模擬數(shù)據(jù),real_Samples為真實(shí)數(shù)據(jù), # 生成器的loss為-D(fake_samples)。 def train(D,G,outdir,z_dimension,num_epochs=30):d_optimizer = torch.optim.Adam(D.parameters(),lr=0.001) # 定義優(yōu)化器g_optimizer = torch.optim.Adam(G.parameters(),lr=0.001)os.makedirs(outdir,exist_ok=True) # 創(chuàng)建輸出文件夾# 在函數(shù)train()中,判別器和生成器是分開訓(xùn)練的。讓判別器學(xué)習(xí)的次數(shù)多一些,判別器每訓(xùn)練5次,生成器優(yōu)化1次。# WGAN_gp不會(huì)因?yàn)榕袆e器準(zhǔn)確率太高而引起生成器梯度消失的問題,所以好的判別器會(huì)讓生成器有更好的模擬效果。for epoch in range(num_epochs):for i,(img,lab) in enumerate(train_loader):num_img = img.size(0)# 訓(xùn)練判別器real_img = img.to(device)y_one_hot = torch.zeros(lab.shape[0],10).scatter_(1,lab.view(lab.shape[0],1),1).to(device)for ii in range(5): # 循環(huán)訓(xùn)練5次d_optimizer.zero_grad() # 梯度清零# 對(duì)real_img進(jìn)行判別real_out = D(real_img,y_one_hot)# 生成隨機(jī)值z(mì) = torch.randn(num_img,z_dimension).to(device)fake_img = G(z,y_one_hot) # 生成fake_imgfake_out = D(fake_img,y_one_hot) # 對(duì)fake_img進(jìn)行判別# 計(jì)算梯度懲罰項(xiàng)gradient_penalty = compute_gradient_penalty(D,real_img.data,fake_img.data,y_one_hot)# 計(jì)算判別器的lossd_loss = -torch.mean(real_out)+torch.mean(fake_out)+gradient_penaltyd_loss.backward()d_optimizer.step()# 訓(xùn)練生成器for ii in range(1): # 訓(xùn)練一次g_optimizer.zero_grad() # 梯度清0z = torch.randn(num_img,z_dimension).to(device)fake_img = G(z,y_one_hot)fake_out = D(fake_img,y_one_hot)g_loss = -torch.mean(fake_out)g_loss.backward()g_optimizer.step()# 輸出可視化結(jié)果,并將生成的結(jié)果以圖片的形式存儲(chǔ)在硬盤中fake_images = to_img(fake_img.cpu().data)real_images = to_img(real_img.cpu().data)rel = torch.cat([to_img(real_images[:10]), fake_images[:10]], axis=0)imshow(torchvision.utils.make_grid(rel, nrow=10),os.path.join(outdir, 'fake_images-{}.png'.format(epoch + 1)))# 輸出訓(xùn)練結(jié)果print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epochs, d_loss.data, g_loss.data,real_out.data.mean(), fake_out.data.mean()))# 保存訓(xùn)練模型torch.save(G.state_dict(), os.path.join(outdir, 'generator.pth'))torch.save(D.state_dict(), os.path.join(outdir, 'discriminator.pth'))2.5?代碼實(shí)戰(zhàn):現(xiàn)可視化模型結(jié)果----WGAN_cond_237.py(第5部分)
# 1.5 定義函數(shù),實(shí)現(xiàn)可視化模型結(jié)果:獲取一部分測(cè)試數(shù)據(jù),顯示由模型生成的模擬數(shù)據(jù)。 def displayAndTest(D,G,z_dimension): # 可視化結(jié)果sample = iter(test_loader)images, labels = sample.next()y_one_hot = torch.zeros(labels.shape[0], 10).scatter_(1,labels.view(labels.shape[0], 1), 1).to(device)num_img = images.size(0) # 獲取樣本個(gè)數(shù)with torch.no_grad():z = torch.randn(num_img, z_dimension).to(device) # 生成隨機(jī)數(shù)fake_img = G(z, y_one_hot)fake_images = to_img(fake_img.cpu().data) # 生成模擬樣本rel = torch.cat([to_img(images[:10]), fake_images[:10]], axis=0)imshow(torchvision.utils.make_grid(rel, nrow=10))print(labels[:10])2.6?定義判別器類CondWGAN_D----WGAN_cond_237.py
(第6部分)
# 1.6 定義判別器類CondWGAN_D # 在判別器和生成器類的正向結(jié)構(gòu)中,增加標(biāo)簽向量的輸入,并使用全連接網(wǎng)絡(luò)對(duì)標(biāo)簽向量的維度進(jìn)行擴(kuò)展,同時(shí)將其連接到輸入數(shù)據(jù)。 class CondWGAN_D(WGAN_D): # 定義判別器類CondWGAN_D,使其繼承自WGAN_D類。def __init__(self, inputch=2):super(CondWGAN_D, self).__init__(inputch)self.labfc1 = nn.Linear(10, 28 * 28)def forward(self, x, lab): # 添加輸入標(biāo)簽,batch, width, height, channel=1d_in = torch.cat((x.view(x.size(0), -1), self.labfc1(lab)), -1)x = d_in.view(d_in.size(0), 2, 28, 28)return super(CondWGAN_D, self).forward(x, lab)2.7?定義生成器類CondWGAN_G----WGAN_cond_237.py(第7部分)
# 1.7 定義生成器類CondWGAN_G # 在判別器和生成器類的正向結(jié)構(gòu)中,增加標(biāo)簽向量的輸入,并使用全連接網(wǎng)絡(luò)對(duì)標(biāo)簽向量的維度進(jìn)行擴(kuò)展,同時(shí)將其連接到輸入數(shù)據(jù)。 class CondWGAN_G(WGAN_G): # 定義生成器類CondWGAN_G,使其繼承自WGAN_G類。def __init__(self, input_size, input_n=2):super(CondWGAN_G, self).__init__(input_size, input_n)self.labfc1 = nn.Linear(10, input_size)def forward(self, x, lab): # 添加輸入標(biāo)簽,batch, width, height, channel=1d_in = torch.cat((x, self.labfc1(lab)), -1)return super(CondWGAN_G, self).forward(d_in, lab)2.8?調(diào)用函數(shù)并訓(xùn)練模型----WGAN_cond_237.py(第6部分)
# 1.8 調(diào)用函數(shù)并訓(xùn)練模型:實(shí)例化判別器和生成器模型,并調(diào)用函數(shù)進(jìn)行訓(xùn)練 if __name__ == '__main__':z_dimension = 40 # 設(shè)置輸入隨機(jī)數(shù)的維度D = CondWGAN_D().to(device) # 實(shí)例化判別器G = CondWGAN_G(z_dimension).to(device) # 實(shí)例化生成器train(D, G, './condw_img', z_dimension) # 訓(xùn)練模型displayAndTest(D, G, z_dimension) # 輸出可視化在訓(xùn)練之后,模型輸出了可視化結(jié)果,如圖所示,第1行是原始樣本,第2行是輸出的模擬樣本。
同時(shí),程序也輸出了圖8-20中樣本對(duì)應(yīng)的類標(biāo)簽,如下:
? ? tensor([9,2,1,1,6,1,4,6,5,7])
從輸出的樣本中可以看到,輸出的模擬樣本與原始樣本的類別一致,這表明生成器可以按照指定的標(biāo)簽生成模擬數(shù)據(jù)。
?3??代碼匯總(WGAN_cond_237.py)
import torch import torchvision from torchvision import transforms from torch.utils.data import DataLoader from torch import nn import torch.autograd as autograd import matplotlib.pyplot as plt import numpy as np import matplotlib import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"# 1.1 引入模塊并載入樣本:定義基本函數(shù),加載FashionMNIST數(shù)據(jù)集 def to_img(x):x = 0.5 * (x+1)x = x.clamp(0,1)x = x.view(x.size(0),1,28,28)return xdef imshow(img,filename = None):npimg = img.numpy()plt.axis('off')array = np.transpose(npimg,(1,2,0))if filename != None:matplotlib.image.imsave(filename,array)else:plt.imshow(array)# plt.savefig(filename) # 保存圖片 注釋掉,因?yàn)闀?huì)報(bào)錯(cuò),暫時(shí)不知道什么原因 2022.3.26 15:20plt.show()img_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],std=[0.5])] )data_dir = './fashion_mnist'train_dataset = torchvision.datasets.FashionMNIST(data_dir,train=True,transform=img_transform,download=True) train_loader = DataLoader(train_dataset,batch_size=1024,shuffle=True) # 測(cè)試數(shù)據(jù)集 val_dataset = torchvision.datasets.FashionMNIST(data_dir,train=False,transform=img_transform) test_loader = DataLoader(val_dataset,batch_size=10,shuffle=False) # 指定設(shè)備 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(device)# 1.2 實(shí)現(xiàn)生成器和判別器 :因?yàn)閺?fù)雜部分都放在loss值的計(jì)算方面了,所以生成器和判別器就會(huì)簡單一些。 # 生成器和判別器各自有兩個(gè)卷積和兩個(gè)全連接層。生成器最終輸出與輸入圖片相同維度的數(shù)據(jù)作為模擬樣本。 # 判別器的輸出不需要有激活函數(shù),并且輸出維度為1的數(shù)值用來表示結(jié)果。 # 在GAN模型中,因判別器的輸入則是具體的樣本數(shù)據(jù),要區(qū)分每個(gè)數(shù)據(jù)的分布特征,所以判別器使用實(shí)例歸一化, class WGAN_D(nn.Module): # 定義判別器類D :有兩個(gè)卷積和兩個(gè)全連接層def __init__(self,inputch=1):super(WGAN_D, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(inputch,64,4,2,1), # 輸出形狀為[batch,64,28,28]nn.LeakyReLU(0.2,True),nn.InstanceNorm2d(64,affine=True))self.conv2 = nn.Sequential(nn.Conv2d(64,128,4,2,1),# 輸出形狀為[batch,64,14,14]nn.LeakyReLU(0.2,True),nn.InstanceNorm2d(128,affine=True))self.fc = nn.Sequential(nn.Linear(128*7*7,1024),nn.LeakyReLU(0.2,True))self.fc2 = nn.Sequential(nn.InstanceNorm1d(1,affine=True),nn.Flatten(),nn.Linear(1024,1))def forward(self,x,*arg): # 正向傳播x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0),-1)x = self.fc(x)x = x.reshape(x.size(0),1,-1)x = self.fc2(x)return x.view(-1,1).squeeze(1)# 在GAN模型中,因生成器的初始輸入是隨機(jī)值,所以生成器使用批量歸一化。 class WGAN_G(nn.Module): # 定義生成器類G:有兩個(gè)卷積和兩個(gè)全連接層def __init__(self,input_size,input_n=1):super(WGAN_G, self).__init__()self.fc1 = nn.Sequential(nn.Linear(input_size * input_n,1024),nn.ReLU(True),nn.BatchNorm1d(1024))self.fc2 = nn.Sequential(nn.Linear(1024,7*7*128),nn.ReLU(True),nn.BatchNorm1d(7*7*128))self.upsample1 = nn.Sequential(nn.ConvTranspose2d(128,64,4,2,padding=1,bias=False), # 輸出形狀為[batch,64,14,14]nn.ReLU(True),nn.BatchNorm2d(64))self.upsample2 = nn.Sequential(nn.ConvTranspose2d(64,1,4,2,padding=1,bias=False), # 輸出形狀為[batch,64,28,28]nn.Tanh())def forward(self,x,*arg): # 正向傳播x = self.fc1(x)x = self.fc2(x)x = x.view(x.size(0),128,7,7)x = self.upsample1(x)img = self.upsample2(x)return img# 1.3 定義函數(shù)compute_gradient_penalty()完成梯度懲罰項(xiàng) # 懲罰項(xiàng)的樣本X_inter由一部分Pg分布和一部分Pr分布組成,同時(shí)對(duì)D(X_inter)求梯度,并計(jì)算梯度與1的平方差,最終得到gradient_penalties lambda_gp = 10 # 計(jì)算梯度懲罰項(xiàng) def compute_gradient_penalty(D,real_samples,fake_samples,y_one_hot):# 獲取一個(gè)隨機(jī)數(shù),作為真假樣本的采樣比例eps = torch.FloatTensor(real_samples.size(0),1,1,1).uniform_(0,1).to(device)# 按照eps比例生成真假樣本采樣值X_interX_inter = (eps * real_samples + ((1-eps)*fake_samples)).requires_grad_(True)d_interpolates = D(X_inter,y_one_hot)fake = torch.full((real_samples.size(0),),1,device=device) # 計(jì)算梯度輸出的掩碼,在本例中需要對(duì)所有梯度進(jìn)行計(jì)算,故需要按照樣本個(gè)數(shù)生成全為1的張量。# 求梯度gradients = autograd.grad(outputs=d_interpolates, # 輸出值outputs,傳入計(jì)算過的張量結(jié)果inputs=X_inter,# 待求梯度的輸入值inputs,傳入可導(dǎo)的張量,即requires_grad=Truegrad_outputs=fake, # 傳出梯度的掩碼grad_outputs,使用1和0組成的掩碼,在計(jì)算梯度之后,會(huì)將求導(dǎo)結(jié)果與該掩碼進(jìn)行相乘得到最終結(jié)果。create_graph=True,retain_graph=True,only_inputs=True)[0]gradients = gradients.view(gradients.size(0),-1)gradient_penaltys = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gpreturn gradient_penaltys# 1.4 定義模型的訓(xùn)練函數(shù) # 定義函數(shù)train(),實(shí)現(xiàn)模型的訓(xùn)練過程。 # 在函數(shù)train()中,按照對(duì)抗神經(jīng)網(wǎng)絡(luò)專題(一)中的式(8-24)實(shí)現(xiàn)模型的損失函數(shù)。 # 判別器的loss為D(fake_samples)-D(real_samples)再加上聯(lián)合分布樣本的梯度懲罰項(xiàng)gradient_penalties,其中fake_samples為生成的模擬數(shù)據(jù),real_Samples為真實(shí)數(shù)據(jù), # 生成器的loss為-D(fake_samples)。 def train(D,G,outdir,z_dimension,num_epochs=30):d_optimizer = torch.optim.Adam(D.parameters(),lr=0.001) # 定義優(yōu)化器g_optimizer = torch.optim.Adam(G.parameters(),lr=0.001)os.makedirs(outdir,exist_ok=True) # 創(chuàng)建輸出文件夾# 在函數(shù)train()中,判別器和生成器是分開訓(xùn)練的。讓判別器學(xué)習(xí)的次數(shù)多一些,判別器每訓(xùn)練5次,生成器優(yōu)化1次。# WGAN_gp不會(huì)因?yàn)榕袆e器準(zhǔn)確率太高而引起生成器梯度消失的問題,所以好的判別器會(huì)讓生成器有更好的模擬效果。for epoch in range(num_epochs):for i,(img,lab) in enumerate(train_loader):num_img = img.size(0)# 訓(xùn)練判別器real_img = img.to(device)y_one_hot = torch.zeros(lab.shape[0],10).scatter_(1,lab.view(lab.shape[0],1),1).to(device)for ii in range(5): # 循環(huán)訓(xùn)練5次d_optimizer.zero_grad() # 梯度清零# 對(duì)real_img進(jìn)行判別real_out = D(real_img,y_one_hot)# 生成隨機(jī)值z(mì) = torch.randn(num_img,z_dimension).to(device)fake_img = G(z,y_one_hot) # 生成fake_imgfake_out = D(fake_img,y_one_hot) # 對(duì)fake_img進(jìn)行判別# 計(jì)算梯度懲罰項(xiàng)gradient_penalty = compute_gradient_penalty(D,real_img.data,fake_img.data,y_one_hot)# 計(jì)算判別器的lossd_loss = -torch.mean(real_out)+torch.mean(fake_out)+gradient_penaltyd_loss.backward()d_optimizer.step()# 訓(xùn)練生成器for ii in range(1): # 訓(xùn)練一次g_optimizer.zero_grad() # 梯度清0z = torch.randn(num_img,z_dimension).to(device)fake_img = G(z,y_one_hot)fake_out = D(fake_img,y_one_hot)g_loss = -torch.mean(fake_out)g_loss.backward()g_optimizer.step()# 輸出可視化結(jié)果,并將生成的結(jié)果以圖片的形式存儲(chǔ)在硬盤中fake_images = to_img(fake_img.cpu().data)real_images = to_img(real_img.cpu().data)rel = torch.cat([to_img(real_images[:10]), fake_images[:10]], axis=0)imshow(torchvision.utils.make_grid(rel, nrow=10),os.path.join(outdir, 'fake_images-{}.png'.format(epoch + 1)))# 輸出訓(xùn)練結(jié)果print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epochs, d_loss.data, g_loss.data,real_out.data.mean(), fake_out.data.mean()))# 保存訓(xùn)練模型torch.save(G.state_dict(), os.path.join(outdir, 'cond_generator.pth'))torch.save(D.state_dict(), os.path.join(outdir, 'cond_discriminator.pth'))# 1.5 定義函數(shù),實(shí)現(xiàn)可視化模型結(jié)果:獲取一部分測(cè)試數(shù)據(jù),顯示由模型生成的模擬數(shù)據(jù)。 def displayAndTest(D,G,z_dimension): # 可視化結(jié)果sample = iter(test_loader)images, labels = sample.next()y_one_hot = torch.zeros(labels.shape[0], 10).scatter_(1,labels.view(labels.shape[0], 1), 1).to(device)num_img = images.size(0) # 獲取樣本個(gè)數(shù)with torch.no_grad():z = torch.randn(num_img, z_dimension).to(device) # 生成隨機(jī)數(shù)fake_img = G(z, y_one_hot)fake_images = to_img(fake_img.cpu().data) # 生成模擬樣本rel = torch.cat([to_img(images[:10]), fake_images[:10]], axis=0)imshow(torchvision.utils.make_grid(rel, nrow=10))print(labels[:10])# 1.6 定義判別器類CondWGAN_D # 在判別器和生成器類的正向結(jié)構(gòu)中,增加標(biāo)簽向量的輸入,并使用全連接網(wǎng)絡(luò)對(duì)標(biāo)簽向量的維度進(jìn)行擴(kuò)展,同時(shí)將其連接到輸入數(shù)據(jù)。 class CondWGAN_D(WGAN_D): # 定義判別器類CondWGAN_D,使其繼承自WGAN_D類。def __init__(self, inputch=2):super(CondWGAN_D, self).__init__(inputch)self.labfc1 = nn.Linear(10, 28 * 28)def forward(self, x, lab): # 添加輸入標(biāo)簽,batch, width, height, channel=1d_in = torch.cat((x.view(x.size(0), -1), self.labfc1(lab)), -1)x = d_in.view(d_in.size(0), 2, 28, 28)return super(CondWGAN_D, self).forward(x, lab)# 1.7 定義生成器類CondWGAN_G # 在判別器和生成器類的正向結(jié)構(gòu)中,增加標(biāo)簽向量的輸入,并使用全連接網(wǎng)絡(luò)對(duì)標(biāo)簽向量的維度進(jìn)行擴(kuò)展,同時(shí)將其連接到輸入數(shù)據(jù)。 class CondWGAN_G(WGAN_G): # 定義生成器類CondWGAN_G,使其繼承自WGAN_G類。def __init__(self, input_size, input_n=2):super(CondWGAN_G, self).__init__(input_size, input_n)self.labfc1 = nn.Linear(10, input_size)def forward(self, x, lab): # 添加輸入標(biāo)簽,batch, width, height, channel=1d_in = torch.cat((x, self.labfc1(lab)), -1)return super(CondWGAN_G, self).forward(d_in, lab)# 1.8 調(diào)用函數(shù)并訓(xùn)練模型:實(shí)例化判別器和生成器模型,并調(diào)用函數(shù)進(jìn)行訓(xùn)練 if __name__ == '__main__':z_dimension = 40 # 設(shè)置輸入隨機(jī)數(shù)的維度D = CondWGAN_D().to(device) # 實(shí)例化判別器G = CondWGAN_G(z_dimension).to(device) # 實(shí)例化生成器train(D, G, './condw_img', z_dimension) # 訓(xùn)練模型displayAndTest(D, G, z_dimension) # 輸出可視化總結(jié)
以上是生活随笔為你收集整理的【Pytorch神经网络实战案例】16 条件WGAN模型生成可控Fashon-MNST模拟数据的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 摩尔斯电码转换python编码_Mors
- 下一篇: 【Pytorch神经网络实战案例】29