Principles and Implementation of Generative Adversarial Networks
Introduction
- GAN stands for generative adversarial network.
- Invented in 2014 by Ian Goodfellow and colleagues in Yoshua Bengio's lab.
- What a GAN does: it trains a "forger" whose output is almost indistinguishable from the real thing.
- How a GAN works: how do you train the forger? With two networks, a generator G and a discriminator D, which improve by competing against each other. The generator is the forger; whatever it produces is handed to the discriminator, which must judge whether each sample comes from the real data or from the generator. At the start of training the generator produces unrecognizable blobs and the discriminator does little better than guessing, but as training proceeds normally, both the generator's image quality and the discriminator's judgment improve. Although the loss curves keep oscillating and never drop far, the networks' abilities can eventually exceed a human's. (In this example the generator loss and discriminator loss behave antagonistically: when one is low the other tends to be high, and the two curves almost never sit at low values at the same time.)
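This competition is exactly the minimax game from Goodfellow et al. (2014): the discriminator D maximizes the value function below while the generator G minimizes it.

```latex
\min_G \max_D V(D, G) =
\mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
+ \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

Here $p_{\text{data}}$ is the real-data distribution and $p_z$ is the prior over the latent vector $z$ fed to the generator.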
A Complete GAN Walkthrough on the MNIST Handwritten Digit Dataset
Setting Up the Environment and Downloading MNIST
```python
%matplotlib inline

import numpy as np
import torch
import matplotlib.pyplot as plt

from torchvision import datasets
import torchvision.transforms as transforms

num_workers = 0
batch_size = 64

transform = transforms.ToTensor()
train_data = datasets.MNIST(root='data', train=True,
                            download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           num_workers=num_workers)
```
Visualizing the Data

```python
dataiter = iter(train_loader)
images, labels = next(dataiter)  # dataiter.next() is deprecated

images = images.numpy()
img = np.squeeze(images[0])

fig = plt.figure(figsize=(3, 3))
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')
```
Defining the GAN Model
A GAN consists of two networks: a discriminator and a generator. The architecture diagram is shown below:
In this example, both the generator and the discriminator are built from fully connected layers:
- The generator's input is a random latent vector (length `z_size`, here 100) with values in (-1, 1). Its output is a one-dimensional vector of 784 values, also in (-1, 1), because the final fully connected layer uses a tanh activation, which bounds the output to (-1, 1). Once the generator is trained, reshaping this 784-vector into 28x28 yields a forged handwritten digit image.
- The discriminator's input is a 28x28 image, which may be a forgery produced by the generator or a real MNIST image. Its output is a single float (a logit). Once the discriminator is trained, an output greater than 0 means it believes the input is a real MNIST image; less than 0 means it believes the input is a forgery.
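One consequence of the tanh output range: the real MNIST pixels, which `ToTensor` leaves in [0, 1), must be rescaled to [-1, 1) before being shown to the discriminator, so real and fake images live in the same range. The training loop later does this with `real_images*2 - 1`; here is the mapping in isolation:

```python
def rescale(x):
    # map a ToTensor pixel value in [0, 1) onto [-1, 1),
    # the range the generator's tanh output lives in
    return x * 2 - 1

# endpoints of the pixel range map as expected
assert rescale(0.0) == -1.0
assert rescale(0.5) == 0.0
```

Without this step, the discriminator could tell real from fake simply by checking the value range, and the generator would never get useful gradients.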
Discriminator Network Code
We want the discriminator to output a value in 0~1 indicating whether the input image is real or forged.
However, later on we will use the BCEWithLogitsLoss loss function for this GAN, which combines a sigmoid activation with BCELoss, so the discriminator's output layer does not need a sigmoid of its own.
```python
import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):

    def __init__(self, input_size, hidden_dim, output_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_dim*4)
        self.fc2 = nn.Linear(hidden_dim*4, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_size)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.leaky_relu(self.fc1(x), 0.2)  # (input, negative_slope=0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        out = self.fc4(x)  # raw logit; sigmoid lives inside BCEWithLogitsLoss
        return out
```
Generator Network Code

```python
class Generator(nn.Module):

    def __init__(self, input_size, hidden_dim, output_size):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim*4)
        self.fc4 = nn.Linear(hidden_dim*4, output_size)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)  # (input, negative_slope=0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        out = torch.tanh(self.fc4(x))  # F.tanh is deprecated; use torch.tanh
        return out
```
[Core] How Are the Discriminator and Generator Trained?
The training scheme is simple yet clever. The two networks are trained separately, but they must be trained together: the discriminator's loss is computed from images the generator produces, and the generator's loss is computed from the discriminator's predictions.
Discriminator training step:
- Take a batch of real images, have the discriminator judge real vs. fake, and compute the loss d_real_loss.
- Feed the generator a batch of random latent vectors; the generator produces a batch of fake 28x28 images. Feed these to the discriminator, have it judge real vs. fake, and compute the loss d_fake_loss.
- Total discriminator loss for this step: d_loss = d_real_loss + d_fake_loss
- Update the discriminator's parameters once.
Generator training step:
- (Immediately after the discriminator update above) the generator produces another batch of fake images, which are fed to the discriminator; the generator's loss is computed from the discriminator's output, using flipped labels, i.e. treating the fakes as if they were real.
- Update the generator's parameters once.
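The flipped-label trick in that last step implements the standard non-saturating generator loss. Rather than minimizing $\log(1 - D(G(z)))$, whose gradient vanishes when the discriminator confidently rejects fakes early in training, the generator maximizes $\log D(G(z))$, which is equivalent to minimizing:

```latex
\mathcal{L}_G = -\,\mathbb{E}_{z \sim p_z(z)}\big[\log D(G(z))\big]
```

This is why the code below reuses real_loss (labels of 1) on fake images when computing the generator's loss.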
Loss Functions

```python
# Calculate losses.
# The only difference between the two functions below is that
# real_loss optionally applies label smoothing.

def real_loss(D_out, smooth=False):
    batch_size = D_out.size(0)
    # label smoothing
    if smooth:
        # smooth, real labels = 0.9
        # (real images are otherwise too easy to fit, which can stall learning early)
        labels = torch.ones(batch_size) * 0.9
    else:
        labels = torch.ones(batch_size)  # real labels = 1
    # numerically stable loss
    criterion = nn.BCEWithLogitsLoss()
    # calculate loss
    loss = criterion(D_out.squeeze(), labels)
    return loss

def fake_loss(D_out):
    batch_size = D_out.size(0)
    labels = torch.zeros(batch_size)  # fake labels = 0
    criterion = nn.BCEWithLogitsLoss()
    # calculate loss
    loss = criterion(D_out.squeeze(), labels)
    return loss
```
Training Code

```python
import pickle as pkl
import torch.optim as optim

# Discriminator hyperparams
input_size = 784     # size of input image to discriminator (28*28)
d_output_size = 1    # size of discriminator output (real or fake)
d_hidden_size = 32   # size of last hidden layer in the discriminator

# Generator hyperparams
z_size = 100         # size of latent vector to give to generator
g_output_size = 784  # size of generator output (generated image)
g_hidden_size = 32   # size of first hidden layer in the generator

# instantiate the two networks
D = Discriminator(input_size, d_hidden_size, d_output_size)
G = Generator(z_size, g_hidden_size, g_output_size)

lr = 0.002
d_optimizer = optim.Adam(D.parameters(), lr)
g_optimizer = optim.Adam(G.parameters(), lr)

num_epochs = 30

# keep track of loss and generated, "fake" samples
samples = []  # generator samples saved after each epoch
losses = []   # (d_loss, g_loss) recorded after each epoch

# Get some fixed data for sampling. These latent vectors are held
# constant throughout training and let us inspect the model's progress.
sample_size = 16
fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
fixed_z = torch.from_numpy(fixed_z).float()

# train the network
D.train()
G.train()
for epoch in range(num_epochs):

    for batch_i, (real_images, _) in enumerate(train_loader):

        batch_size = real_images.size(0)

        ## Important rescaling step ##
        real_images = real_images*2 - 1  # rescale input images from [0,1) to [-1,1)

        # ============================================
        #            TRAIN THE DISCRIMINATOR
        # ============================================
        d_optimizer.zero_grad()

        # 1. Train with real images
        # Compute the discriminator losses on real images,
        # smoothing the real labels
        D_real = D(real_images)
        d_real_loss = real_loss(D_real, smooth=True)

        # 2. Train with fake images
        # Generate fake images
        z = np.random.uniform(-1, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float()
        fake_images = G(z)

        # Compute the discriminator losses on fake images
        D_fake = D(fake_images)
        d_fake_loss = fake_loss(D_fake)

        # add up loss and perform backprop
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # =========================================
        #            TRAIN THE GENERATOR
        # =========================================
        g_optimizer.zero_grad()

        # 1. Train with fake images and flipped labels
        # Generate fake images
        z = np.random.uniform(-1, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float()
        fake_images = G(z)

        # Compute the discriminator losses on fake images
        # using flipped labels!
        D_fake = D(fake_images)
        g_loss = real_loss(D_fake)  # use real loss to flip labels

        # perform backprop
        g_loss.backward()
        g_optimizer.step()

    print('Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
        epoch+1, num_epochs, d_loss.item(), g_loss.item()))

    ## AFTER EACH EPOCH ##
    # append discriminator loss and generator loss
    losses.append((d_loss.item(), g_loss.item()))

    # generate and save sample, fake images
    G.eval()  # eval mode for generating samples
    samples_z = G(fixed_z)
    samples.append(samples_z)
    G.train()  # back to train mode

# Save training generator samples
with open('train_samples.pkl', 'wb') as f:  # save each epoch's samples to a pkl file
    pkl.dump(samples, f)
```

Note: the original post never instantiated `D` and `G` before creating the optimizers; the two constructor calls above have been added so the code actually runs.
The loss curves over 30 epochs look like this:
As the plot shows, the losses are hard to drive down and oscillate sharply. In practice the generator loss and discriminator loss move in opposition: when the discriminator is strong the generator is weak, so one loss is high while the other is low. The lagging network then takes larger gradient steps, soon overtakes the other, the losses swap places, and the discriminator speeds up again, and so on.
Training for 100 epochs looks much the same; judged by loss alone, the two never converge (ignore the initial spike):
Visualizing the Generator's Output After Each Epoch

```python
# Load samples from generator, taken while training
with open('train_samples.pkl', 'rb') as f:
    samples = pkl.load(f)

rows = 30
cols = 16  # samples shown per row (only 16 were generated per epoch, so 16 is the max)

fig, axes = plt.subplots(figsize=(14, 28), nrows=rows, ncols=cols,
                         sharex=True, sharey=True)

for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
        img = img.detach()
        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
```
Keep in mind that the generator's input is always a random latent vector sampled uniformly from (-1, 1), which looks like this:
As the figure below shows, after just one epoch the generator has already learned to place a cluster of white pixels in the middle of the image and keep the borders dark.
After a few more epochs, it starts to forge recognizable digits!
Testing the Generator

```python
# helper function for viewing a list of passed-in sample images
def view_samples(epoch, samples):
    fig, axes = plt.subplots(figsize=(7, 7), nrows=4, ncols=4,
                             sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        img = img.detach()
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28, 28)), cmap='Greys_r')

# randomly generated, new latent vectors
sample_size = 16
rand_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
rand_z = torch.from_numpy(rand_z).float()

G.eval()  # eval mode
# generated samples
rand_images = G(rand_z)

# 0 indicates the first set of samples in the passed-in list,
# and we only have one batch of samples here
view_samples(0, [rand_images])
```
Summary
A GAN pits a generator against a discriminator. Even though their losses oscillate and never converge in the usual sense, that adversarial tug-of-war is precisely what teaches the generator to produce convincing handwritten digits.