使用Jittor实现Conditional GAN
Jittor實(shí)現(xiàn)Conditional GAN
Generative Adversarial Nets(GAN)提出了一種新的方法來訓(xùn)練生成模型。然而,GAN對于要生成的圖片缺少控制。Conditional GAN(CGAN)通過添加顯式的條件或標(biāo)簽,來控制生成的圖像。本文講解了CGAN的網(wǎng)絡(luò)結(jié)構(gòu)、損失函數(shù)設(shè)計(jì)、使用CGAN生成一串?dāng)?shù)字、從頭訓(xùn)練CGAN、以及在mnist手寫數(shù)字?jǐn)?shù)據(jù)集上的訓(xùn)練結(jié)果。
CGAN網(wǎng)絡(luò)架構(gòu)
通過在生成器generator和判別器discriminator中添加相同的額外信息y,GAN就可以擴(kuò)展為一個(gè)conditional模型。y可以是任何形式的輔助信息,例如類別標(biāo)簽或者其他形式的數(shù)據(jù)。可以通過將y作為額外輸入層,添加到生成器和判別器來完成條件控制。
在生成器generator中,除了y之外,還額外輸入隨機(jī)一維噪聲z,為結(jié)果生成提供更多靈活性。
損失函數(shù)
GAN的損失函數(shù)
在解釋CGAN的損失函數(shù)之前,首先介紹GAN的損失函數(shù)。下面是GAN的損失函數(shù)設(shè)計(jì)。
對于判別器D,要訓(xùn)練最大化這個(gè)loss。如果D的輸入是來自真實(shí)樣本的數(shù)據(jù)x,則D的輸出D(x)要盡可能地大,log(D(x))也會(huì)盡可能大。如果D的輸入是來自G生成的假圖片G(z),則D的輸出D(G(z))應(yīng)盡可能地小,從而log(1-D(G(z))會(huì)盡可能地大。這樣可以達(dá)到max D的目的。
對于生成器G,要訓(xùn)練最小化這個(gè)loss。對于G生成的假圖片G(z),希望盡可能地騙過D,讓它覺得生成的圖片就是真的圖片,這樣就達(dá)到了G“以假亂真”的目的。那么D的輸出D(G(z))應(yīng)盡可能地大,從而log(1-D(G(z))會(huì)盡可能地小。這樣可以達(dá)到min G的目的。
D和G以這樣的方式聯(lián)合訓(xùn)練,最終達(dá)到G的生成能力越來越強(qiáng),D的判別能力越來越強(qiáng)的目的。
CGAN的損失函數(shù)
下面是CGAN的損失函數(shù)設(shè)計(jì)。
很明顯,CGAN的loss跟GAN的loss的區(qū)別就是多了條件限定y。D(x/y)代表在條件y下,x為真的概率。D(G(z/y))表示在條件y下,G生成的圖片被D判別為真的概率。
Jittor代碼數(shù)字生成
首先,導(dǎo)入需要的包,并且設(shè)置好所需的超參數(shù):
import jittor as jt
from jittor import nn
import numpy as np
import pylab as pl
%matplotlib inline
隱空間向量長度
latent_dim = 100
類別數(shù)量
n_classes = 10
圖片大小
img_size = 32
圖片通道數(shù)量
channels = 1
圖片張量的形狀
img_shape = (channels, img_size, img_size)
第一步,定義生成器G。該生成器輸入兩個(gè)一維向量y和noise,生成一張圖片。
class Generator(nn.Module):
def init(self):
super(Generator, self).init()
self.label_emb = nn.Embedding(n_classes, n_classes)
def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2))return layersself.model = nn.Sequential(*block((latent_dim + n_classes), 128, normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), nn.Linear(1024, int(np.prod(img_shape))), nn.Tanh())def execute(self, noise, labels):gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)img = self.model(gen_input)img = img.view((img.shape[0], *img_shape))return img
第二步,定義判別器D。D輸入一張圖片和對應(yīng)的y,輸出是真圖片的概率。
class Discriminator(nn.Module):
def init(self):
super(Discriminator, self).init()
self.label_embedding = nn.Embedding(n_classes, n_classes)
self.model = nn.Sequential(
nn.Linear((n_classes + int(np.prod(img_shape))), 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512, 1))
def execute(self, img, labels):d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)validity = self.model(d_in)return validity
第三步,使用CGAN生成一串?dāng)?shù)字。
代碼如下。可以使用訓(xùn)練好的模型來生成圖片,也可以使用提供的預(yù)訓(xùn)練參數(shù): 模型預(yù)訓(xùn)練參數(shù)下載:https://cloud.tsinghua.edu.cn/d/fbe30ae0967942f6991c/。
下載提供的預(yù)訓(xùn)練參數(shù)
!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/generator_last.pkl
!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/discriminator_last.pkl
生成自定義的數(shù)字:
定義模型
generator = Generator()
discriminator = Discriminator()
generator.eval()
discriminator.eval()
加載參數(shù)
generator.load(’./generator_last.pkl’)
discriminator.load(’./discriminator_last.pkl’)
定義一串?dāng)?shù)字
number = “201962517”
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z,labels)
pl.imshow(gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1)))
生成結(jié)果如下,測試的完整代碼在https://github.com/Jittor/gan-jittor/blob/master/models/cgan/test.py。
從頭訓(xùn)練Condition GAN
從頭訓(xùn)練 Condition GAN 的完整代碼在https://github.com/Jittor/gan-jittor/blob/master/models/cgan/cgan.py,下載下來看看!
!wget https://raw.githubusercontent.com/Jittor/gan-jittor/master/models/cgan/cgan.py
!python3.7 ./cgan.py --help
選擇合適的batch size,運(yùn)行試試
運(yùn)行命令: !python3.7 ./cgan.py --batch_size 8
下載下來的代碼里面定義損失函數(shù)、數(shù)據(jù)集、優(yōu)化器。損失函數(shù)采用MSELoss、數(shù)據(jù)集采用MNIST、優(yōu)化器采用Adam 如下(此段代碼僅僅用于解釋意圖,不能運(yùn)行,需要運(yùn)行請運(yùn)行完整文件cgan.py):
此段代碼僅僅用于解釋意圖,不能運(yùn)行,需要運(yùn)行請運(yùn)行完整文件cgan.py
Define Loss
adversarial_loss = nn.MSELoss()
Define Model
generator = Generator()
discriminator = Discriminator()
Define Dataloader
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
transform = transform.Compose([
transform.Resize(opt.img_size),
transform.Gray(),
transform.ImageNormalize(mean=[0.5], std=[0.5]),
])
dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)
optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
模型訓(xùn)練的代碼如下(此段代碼僅僅用于解釋意圖,不能運(yùn)行,需要運(yùn)行請運(yùn)行完整文件cgan.py):
此段代碼僅僅用于解釋意圖,不能運(yùn)行,需要運(yùn)行請運(yùn)行完整文件cgan.py
valid表示真,fake表示假
valid = jt.ones([batch_size, 1]).float32().stop_grad()
fake = jt.zeros([batch_size, 1]).float32().stop_grad()
真實(shí)圖像和對應(yīng)的標(biāo)簽
real_imgs = jt.array(imgs)
labels = jt.array(labels)
#########################################################
訓(xùn)練生成器G
- 希望生成的圖片盡可能地讓D覺得是valid
#########################################################
隨機(jī)向量z和隨機(jī)生成的標(biāo)簽
z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32()
隨機(jī)向量z和隨機(jī)生成的標(biāo)簽經(jīng)過生成器G生成的圖片,希望判別器能夠認(rèn)為生成的圖片和生成的標(biāo)簽是一致的,以此優(yōu)化生成器G的生成能力。
gen_imgs = generator(z, gen_labels)
validity = discriminator(gen_imgs, gen_labels)
g_loss = adversarial_loss(validity, valid)
g_loss.sync()
optimizer_G.step(g_loss)
#########################################################
訓(xùn)練判別器D
- 盡可能識(shí)別real_imgs為valid
- 盡可能識(shí)別gen_imgs為fake
#########################################################
真實(shí)的圖片和標(biāo)簽經(jīng)過判別器的結(jié)果,要盡可能接近valid。
validity_real = discriminator(real_imgs, labels)
d_real_loss = adversarial_loss(validity_real, valid)
G生成的圖片和對應(yīng)的標(biāo)簽經(jīng)過判別器的結(jié)果,要盡可能接近fake。
validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
d_fake_loss = adversarial_loss(validity_fake, fake)
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.sync()
optimizer_D.step(d_loss)
MNIST數(shù)據(jù)集訓(xùn)練結(jié)果
下面展示了Jittor版CGAN在MNIST數(shù)據(jù)集的訓(xùn)練結(jié)果。下面分別是訓(xùn)練0 epoch和90 epoches的結(jié)果。
總結(jié)
以上是生活随笔為你收集整理的使用Jittor实现Conditional GAN的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 计图(Jittor) 1.1版本:新增骨
- 下一篇: XLearning - 深度学习调度平台