Hung-yi Lee Machine Learning Homework 6: Generating Anime Faces with a GAN
Theory reference: Hung-yi Lee Machine Learning: Generative Adversarial Networks (GAN), iwill323's CSDN blog
Contents
Task and dataset
Evaluation
FID
AFD (Anime face detection) rate
Code
Imports
Building the dataset
Displaying some images
Model setup
Generator
Discriminator
Weight initialization
Training
Procedure
Loss functions
Binary classification
discriminator
generator
WGAN
Training function
Training
Loading the data
Set config
Inference
GAN results
Task and dataset
1. Input: random numbers with shape (batch size, number of features)
2. Output: anime faces
3. Implementation requirement: DCGAN & WGAN & WGAN-GP
4. Target: generate 1000 anime faces
The data come from the Crypko website and contain 71,314 images. They can be downloaded via the post "Hung-yi Lee 2022 Machine Learning HW6 walkthrough" on 机器学习手艺人's CSDN blog.
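Evaluation
FID
Real and generated images are fed into a separate model (for FID, a pretrained Inception network) to extract features. The score is the distance between the real and generated feature distributions, specifically the Fréchet distance between Gaussians fitted to the two feature sets. Lower is better.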
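The sketch below covers only the distance step. It assumes the features have already been extracted into two NumPy arrays; in real FID these are Inception pool features. FID fits a Gaussian to each set and computes $\lVert\mu_r - \mu_g\rVert^2 + \mathrm{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r\Sigma_g)^{1/2})$. The name `fid_from_features` is illustrative. The trace of the matrix square root is computed from eigenvalues instead of `scipy.linalg.sqrtm`, so the example needs only NumPy.

```python
import numpy as np

def fid_from_features(feat_real, feat_fake):
    """Frechet distance between Gaussians fitted to two (n_samples, n_features) arrays."""
    mu_r, mu_f = feat_real.mean(axis=0), feat_fake.mean(axis=0)
    cov_r = np.cov(feat_real, rowvar=False)
    cov_f = np.cov(feat_fake, rowvar=False)
    # Tr((cov_r @ cov_f)^(1/2)) is the sum of square roots of the eigenvalues of cov_r @ cov_f.
    # For covariance matrices these eigenvalues are real and non-negative; tiny negative or
    # imaginary parts are numerical noise, so keep the real part and clip at 0.
    eigvals = np.linalg.eigvals(cov_r @ cov_f)
    tr_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(cov_r) + np.trace(cov_f) - 2.0 * tr_sqrt)

# Usage with random stand-in "features"
rng = np.random.default_rng(0)
real = rng.normal(size=(2000, 8))
fake = rng.normal(loc=0.5, size=(2000, 8))
print(fid_from_features(real, fake))
```

Identical feature sets score about 0. Shifting every feature by a constant only moves the mean, so the score grows by the squared length of the shift.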
AFD (Anime face detection) rate
1. Detects how many anime faces are in your submission
2. The higher, the better
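The sketch below shows only the metric's bookkeeping. It assumes the rate is the fraction of submitted images in which an anime-face detector finds a face. The detector itself is not shown: `detect_face` is a hypothetical callable standing in for it, and `afd_rate` is an illustrative name.

```python
from typing import Callable, Iterable

def afd_rate(images: Iterable, detect_face: Callable[[object], bool]) -> float:
    """Fraction of images in which the (hypothetical) detector reports an anime face."""
    images = list(images)
    if not images:
        return 0.0
    detected = sum(1 for img in images if detect_face(img))
    return detected / len(images)

# Usage with a dummy detector that "finds" a face in even-numbered items
print(afd_rate([1, 2, 3, 4], lambda img: img % 2 == 0))
```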
Code
Imports
```python
# import module
import os
import glob
import random
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch import autograd
from torch.autograd import Variable

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import logging
from tqdm import tqdm

# seed setting
def same_seeds(seed):
    # Python built-in random module
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Torch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

same_seeds(2022)
workspace_dir = '../input'
```
Building the dataset
Note that fnames is a list of file paths. Unlike the original sample code, this version reads the images with Image.open().
```python
# prepare for CrypkoDataset
class CrypkoDataset(Dataset):
    def __init__(self, fnames, transform):
        self.transform = transform
        self.fnames = fnames
        self.num_samples = len(self.fnames)

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = Image.open(fname).convert('RGB')  # make sure every image has 3 channels
        img = self.transform(img)
        return img

    def __len__(self):
        return self.num_samples

def get_dataset(root):
    # glob.glob returns the list of files matching the wildcard pattern
    fnames = glob.glob(os.path.join(root, '*'))  # list
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),  # pixels in [0, 1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # -> [-1, 1]
    ])
    dataset = CrypkoDataset(fnames, transform)
    return dataset
```
Displaying some images
```python
temp_dataset = get_dataset(os.path.join(workspace_dir, 'faces'))

# the dataset is normalized to [-1, 1]; map it back to [0, 1] before plotting
images = [(temp_dataset[i] + 1) / 2 for i in range(4)]
grid_img = torchvision.utils.make_grid(images, nrow=4)
plt.figure(figsize=(10, 10))
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()
```
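Model setup
Generator
The generator maps the input vector z into the real data space. Here the data are images, so z must be turned into a 3×64×64 RGB image. In this implementation z first passes through a fully connected layer and is reshaped to (feature_dim × 8) × 4 × 4. It then goes through a series of 2D transposed convolutions, each followed by a 2D batch norm layer and a ReLU activation. The generator's output goes through tanh, so it lies in [-1, 1]. Following each transposed convolution with a batch norm layer is one of the main contributions of the DCGAN paper, and these layers help gradients flow during training.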
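The [-1, 1] range matters because the dataset transform, transforms.Normalize(mean=0.5, std=0.5), maps the [0, 1] pixels from ToTensor() to [-1, 1], which is the same range tanh produces. Generated images therefore have to be mapped back with (x + 1) / 2 before saving or plotting. Here is a NumPy-only sketch of that arithmetic. `normalize` and `denormalize` are illustrative helper names and do not appear in the original code.

```python
import numpy as np

def normalize(x, mean=0.5, std=0.5):
    # per-channel arithmetic of transforms.Normalize(mean=0.5, std=0.5): [0, 1] -> [-1, 1]
    return (x - mean) / std

def denormalize(y):
    # inverse used on generator outputs before saving/plotting: [-1, 1] -> [0, 1]
    return (y + 1.0) / 2.0

pixels = np.array([0.0, 0.25, 0.5, 1.0])
print(normalize(pixels))               # same range as a tanh output
print(denormalize(np.tanh(np.array([-3.0, 0.0, 3.0]))))  # back inside [0, 1]
```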
For transposed convolution, see "How ConvTranspose2d works: how do deep networks upsample?" on 月下花弄影's CSDN blog.
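We can check that each transposed convolution doubles the height and width by evaluating the output-size formula from the PyTorch ConvTranspose2d documentation with the generator's settings (kernel_size=5, stride=2, padding=2, output_padding=1). The helper `conv_transpose2d_out` is illustrative.

```python
def conv_transpose2d_out(h_in, kernel_size, stride=1, padding=0, output_padding=0, dilation=1):
    # output-size formula from the PyTorch ConvTranspose2d documentation
    return (h_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

h = 4  # spatial size after l1 and the view(..., 4, 4) reshape
sizes = [h]
for _ in range(4):  # three dconv_bn_relu blocks in l2 plus the ConvTranspose2d in l3
    h = conv_transpose2d_out(h, kernel_size=5, stride=2, padding=2, output_padding=1)
    sizes.append(h)
print(sizes)  # [4, 8, 16, 32, 64]
```

With these settings the formula reduces to 2 × h_in, so four layers take 4×4 to 64×64.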
```python
# Generator
class Generator(nn.Module):
    """
    Input shape: (batch, in_dim)
    Output shape: (batch, 3, 64, 64)
    """
    def __init__(self, in_dim, feature_dim=64):
        super().__init__()

        # input: (batch, 100)
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, feature_dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(feature_dim * 8 * 4 * 4),
            nn.ReLU()
        )
        self.l2 = nn.Sequential(
            self.dconv_bn_relu(feature_dim * 8, feature_dim * 4),  # (batch, feature_dim * 4, 8, 8)
            self.dconv_bn_relu(feature_dim * 4, feature_dim * 2),  # (batch, feature_dim * 2, 16, 16)
            self.dconv_bn_relu(feature_dim * 2, feature_dim),      # (batch, feature_dim, 32, 32)
        )
        self.l3 = nn.Sequential(
            nn.ConvTranspose2d(feature_dim, 3, kernel_size=5, stride=2,
                               padding=2, output_padding=1, bias=False),  # (batch, 3, 64, 64)
            nn.Tanh()
        )
        self.apply(weights_init)

    def dconv_bn_relu(self, in_dim, out_dim):
        return nn.Sequential(
            nn.ConvTranspose2d(in_dim, out_dim, kernel_size=5, stride=2,
                               padding=2, output_padding=1, bias=False),  # double height and width
            nn.BatchNorm2d(out_dim),
            nn.ReLU(True)
        )

    def forward(self, x):
        y = self.l1(x)
        y = y.view(y.size(0), -1, 4, 4)  # (batch, feature_dim * 8, 4, 4)
        y = self.l2(y)
        y = self.l3(y)
        return y
```
Discriminator
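The discriminator takes a 3×64×64 input and outputs a probability (score). The input passes through convolution, BN and LeakyReLU layers in turn, and a final sigmoid produces the score.
WGAN trains the discriminator as a distance function: it estimates the Wasserstein distance between the real and generated distributions. For that reason the discriminator does not need the final nonlinear sigmoid layer.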
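The code below shrinks the 64×64 input to a single score per image. We can check this with the output-size formula from the PyTorch Conv2d documentation. The helper `conv2d_out` is illustrative.

```python
def conv2d_out(h_in, kernel_size, stride=1, padding=0, dilation=1):
    # output-size formula from the PyTorch Conv2d documentation (floor division)
    return (h_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

h = 64
sizes = [h]
for _ in range(4):  # first Conv2d plus three conv_bn_lrelu blocks: kernel 4, stride 2, padding 1
    h = conv2d_out(h, kernel_size=4, stride=2, padding=1)
    sizes.append(h)
h = conv2d_out(h, kernel_size=4, stride=1, padding=0)  # last Conv2d: one score per image
sizes.append(h)
print(sizes)  # [64, 32, 16, 8, 4, 1]
```

This is why forward() can flatten the (batch, 1, 1, 1) output with view(-1) into a (batch,) vector.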
```python
# Discriminator
class Discriminator(nn.Module):
    """
    Input shape: (batch, 3, 64, 64)
    Output shape: (batch)
    """
    def __init__(self, model_type, in_dim, feature_dim=64):
        super(Discriminator, self).__init__()

        # input: (batch, 3, 64, 64)
        # Remove last sigmoid layer for WGAN
        self.model_type = model_type
        self.l1 = nn.Sequential(
            nn.Conv2d(in_dim, feature_dim, kernel_size=4, stride=2, padding=1),  # (batch, feature_dim, 32, 32)
            nn.LeakyReLU(0.2),
            self.conv_bn_lrelu(feature_dim, feature_dim * 2),      # (batch, feature_dim * 2, 16, 16)
            self.conv_bn_lrelu(feature_dim * 2, feature_dim * 4),  # (batch, feature_dim * 4, 8, 8)
            self.conv_bn_lrelu(feature_dim * 4, feature_dim * 8),  # (batch, feature_dim * 8, 4, 4)
            nn.Conv2d(feature_dim * 8, 1, kernel_size=4, stride=1, padding=0)  # (batch, 1, 1, 1)
        )
        if self.model_type == 'GAN':
            # only the plain GAN outputs a probability
            self.l1.add_module('sigmoid', nn.Sigmoid())
        self.apply(weights_init)

    def conv_bn_lrelu(self, in_dim, out_dim):
        layer = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, 4, 2, 1),  # halve height and width
            nn.BatchNorm2d(out_dim),
            nn.LeakyReLU(0.2),
        )
        if self.model_type == 'WGAN-GP':
            # the gradient penalty is per sample, so batch norm is replaced by instance norm
            layer[1] = nn.InstanceNorm2d(out_dim)
        return layer

    def forward(self, x):
        y = self.l1(x)
        y = y.view(-1)
        return y
```
Weight initialization
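The DCGAN paper specifies that all weights are initialized randomly from a normal distribution with mean 0 and standard deviation 0.02. The weights_init function takes an already constructed model and re-initializes its convolution, transposed convolution and batch normalization layers. Batch norm weights are drawn from N(1, 0.02) and their biases are set to 0. The function is applied after the model has been built.
It is called in the __init__ of both the generator and the discriminator: self.apply(weights_init)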
```python
# setting for weight init function
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:          # Conv2d and ConvTranspose2d
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:   # BatchNorm1d and BatchNorm2d
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
```
Training
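Procedure
- When training the generator, the fake images it produced must be fed into the discriminator again to get new scores. By this point the discriminator has already been updated, and the generator has to fool this updated discriminator. This is what Prof. Hung-yi Lee described in class as the generator and discriminator pushing each other forward.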
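Loss functions
Binary classification
As Prof. Lee explained in class, GAN training is a min-max game, but almost nobody actually implements it with gradient ascent, so the implementation differs somewhat from the theory. The idea behind a GAN is closely related to binary classification. So first consider the binary classification loss, which we want to make as small as possible: $L(\hat{y}, y) = -[y \log \hat{y} + (1 - y) \log(1 - \hat{y})]$.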
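When $y = 1$, $L(\hat{y}, y) = -\log \hat{y}$. As $\hat{y}$ approaches 1, $L(\hat{y}, y) \to 0$, which means a good prediction. As $\hat{y}$ approaches 0, $L(\hat{y}, y) \to +\infty$, which means a bad prediction.
When $y = 0$, $L(\hat{y}, y) = -\log(1 - \hat{y})$. As $\hat{y}$ approaches 0, $L(\hat{y}, y) \to 0$, which means a good prediction. As $\hat{y}$ approaches 1, $L(\hat{y}, y) \to +\infty$, which means a bad prediction.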
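Both cases can be checked numerically with a plain-Python version of the loss. `bce` is an illustrative name. The `eps` clamp, which keeps `log()` away from 0, is an added assumption and not part of the formula.

```python
import math

def bce(y_hat, y, eps=1e-12):
    """Binary cross-entropy L(y_hat, y) = -[y*log(y_hat) + (1-y)*log(1-y_hat)]."""
    y_hat = min(max(y_hat, eps), 1.0 - eps)  # clamp so log() never sees exactly 0
    return -(y * math.log(y_hat) + (1 - y) * math.log(1.0 - y_hat))

for y in (1, 0):
    for y_hat in (0.99, 0.5, 0.01):
        print(f"y={y}, y_hat={y_hat}: loss={bce(y_hat, y):.4f}")
```

For label 1 the loss is small when y_hat is near 1 and large when it is near 0. Label 0 mirrors this.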
discriminator
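The discriminator objective from Prof. Lee's slides is $D^* = \arg\max_D V(G, D)$, where $V(G, D) = \mathbb{E}_{y \sim P_{data}}[\log D(y)] + \mathbb{E}_{y \sim P_G}[\log(1 - D(y))]$.
Apply the binary classification loss with $\hat{y} = D(y)$. When $y$ is sampled from $P_{data}$, the label is 1 and the loss is $-\log \hat{y}$. When $y$ is sampled from $P_G$, the label is 0 and the loss is $-\log(1 - \hat{y})$. Adding the two (in expectation) gives exactly $-V(G, D)$, so maximizing $V$ is the same as minimizing this loss. In other words, the discriminator can be trained directly with binary cross-entropy (BCELoss), using label 1 for real images and label 0 for generated images.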
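Here is a small numeric check of this equivalence, using made-up discriminator outputs. Averaging the two BCE terms, as `loss_D = (r_loss + f_loss) / 2` does, gives exactly $-V/2$ for the sampled batch. When $D$ outputs 1/2 everywhere the loss equals $\ln 2 \approx 0.693$. The function names are illustrative.

```python
import math

def bce(y_hat, y):
    return -(y * math.log(y_hat) + (1 - y) * math.log(1.0 - y_hat))

def loss_d_bce(d_real, d_fake):
    # mirrors loss_D = (r_loss + f_loss) / 2 with label 1 for real and label 0 for fake
    r_loss = sum(bce(p, 1) for p in d_real) / len(d_real)
    f_loss = sum(bce(p, 0) for p in d_fake) / len(d_fake)
    return (r_loss + f_loss) / 2

def value_v(d_real, d_fake):
    # Monte Carlo estimate of V(G, D) = E_data[log D(y)] + E_G[log(1 - D(y))]
    return (sum(math.log(p) for p in d_real) / len(d_real)
            + sum(math.log(1.0 - p) for p in d_fake) / len(d_fake))

d_real = [0.9, 0.8, 0.7]   # made-up D outputs on real images
d_fake = [0.2, 0.1, 0.3]   # made-up D outputs on generated images
print(loss_d_bce(d_real, d_fake), -value_v(d_real, d_fake) / 2)  # the two numbers agree
```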
```python
r_label = torch.ones((bs)).to(self.device)   # real images are labeled 1
f_label = torch.zeros((bs)).to(self.device)  # generated images are labeled 0
r_loss = self.loss(r_logit, r_label)
f_loss = self.loss(f_logit, f_label)
loss_D = (r_loss + f_loss) / 2
```
generator
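The generator objective from Prof. Lee's slides is $G^* = \arg\min_G \max_D V(G, D)$.
The first term of $V(G, D)$ does not depend on $G$. Dropping it leaves $\min_G \mathbb{E}_{z}[\log(1 - D(G(z)))]$.
$D(G(z))$ lies between 0 and 1, so $\log(1 - D(G(z)))$ is unbounded below and tends to $-\infty$ as $D(G(z)) \to 1$. The trouble is the shape of this curve. Its slope is large only near $D(G(z)) = 1$. Early in training, though, the discriminator easily rejects the fakes, so $D(G(z))$ is close to 0. There the curve is almost flat, and the generator receives very little gradient. For this reason, in practice the generator is not trained by minimizing this objective with gradient descent. A replacement objective is used instead: maximize $\mathbb{E}_{z}[\log D(G(z))]$, i.e. minimize $-\log D(G(z))$. For details, see the CS231n course and the notes "CS231n 2022 slide notes: Generative Modeling" on iwill323's CSDN blog.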
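To see the saturation concretely, assume $D(G(z)) = \sigma(a)$ for a logit $a$, as in the GAN-mode discriminator, which ends in nn.Sigmoid. Then $\frac{d}{da}\log(1 - \sigma(a)) = -\sigma(a)$, while $\frac{d}{da}[-\log \sigma(a)] = -(1 - \sigma(a))$. The sketch below evaluates both. The function names are illustrative.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    # d/da log(1 - sigmoid(a)) = -sigmoid(a): tiny when D(G(z)) is near 0
    return -sigmoid(a)

def grad_non_saturating(a):
    # d/da [-log sigmoid(a)] = -(1 - sigmoid(a)): close to -1 when D(G(z)) is near 0
    return -(1.0 - sigmoid(a))

for d in (0.01, 0.5, 0.99):
    a = math.log(d / (1.0 - d))  # logit that makes sigmoid(a) == d
    print(f"D(G(z))={d}: saturating={grad_saturating(a):+.2f}, "
          f"non-saturating={grad_non_saturating(a):+.2f}")
```

Suppose a fake is confidently rejected ($D(G(z)) = 0.01$). The original loss then gives a gradient of about -0.01, while the replacement gives about -0.99. The replacement therefore gives the generator a useful signal exactly when it needs one.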
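Apply the binary classification loss with $\hat{y} = D(G(z))$ and label $y = 1$. This gives $-\log \hat{y} = -\log D(G(z))$, so the generator can also be trained directly with binary cross-entropy (BCELoss) by labeling the generated images as 1 (real).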
```python
loss_G = self.loss(f_logit, r_label)  # generated images scored against label 1
```
WGAN
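Loss function
```python
loss_D = -torch.mean(r_logit) + torch.mean(f_logit)
```
For WGAN-GP I followed the code in "Hung-yi Lee 2022 Machine Learning HW6 walkthrough" on 机器学习手艺人's CSDN blog, but I could not get it to work: after 30 epochs it was still generating noise.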
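Below is a toy NumPy sketch of the WGAN pieces. It uses a hypothetical linear critic $f(x) = w \cdot x + b$, chosen only because its gradient with respect to the input is simply $w$. The sketch covers:

- the critic loss $-\mathbb{E}[f(x_{real})] + \mathbb{E}[f(x_{fake})]$
- the generator loss $-\mathbb{E}[f(x_{fake})]$
- WGAN weight clipping
- the WGAN-GP penalty, which the WGAN-GP paper defines on the L2 norm of the critic's gradient, $(\lVert\nabla f\rVert_2 - 1)^2$, weighted by a coefficient $\lambda$ (10 in the paper)

Real critics are CNNs, where the gradient has to come from autograd.

```python
import numpy as np

# Hypothetical linear critic f(x) = w . x + b (its input gradient is w everywhere).
def critic(x, w, b):
    return x @ w + b

def critic_loss(real, fake, w, b):
    # WGAN critic loss: -E[f(real)] + E[f(fake)]
    return -critic(real, w, b).mean() + critic(fake, w, b).mean()

def generator_loss(fake, w, b):
    # WGAN generator loss: -E[f(fake)]
    return -critic(fake, w, b).mean()

def clip_weights(w, clip_value):
    # WGAN weight clipping: force every parameter into [-c, c]
    return np.clip(w, -clip_value, clip_value)

def gradient_penalty(w):
    # For a linear critic the gradient at any interpolated input is w,
    # so the WGAN-GP term (||grad||_2 - 1)^2 reduces to (||w||_2 - 1)^2.
    return (np.linalg.norm(w, ord=2) - 1.0) ** 2

w = np.array([3.0, 4.0])
real = np.array([[1.0, 0.0]])
fake = np.array([[0.0, 1.0]])
print(critic_loss(real, fake, w, 0.0), generator_loss(fake, w, 0.0))
print(clip_weights(w, 1.0), gradient_penalty(w))
```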
Training function
```python
class TrainerGAN():
    def __init__(self, config, device):
        self.config = config
        self.model_type = self.config["model_type"]
        self.device = device

        self.G = Generator(self.config["z_dim"])
        self.D = Discriminator(self.model_type, 3)  # 3 is the number of input channels

        self.loss = nn.BCELoss()

        if self.model_type == 'GAN' or self.model_type == 'WGAN-GP':
            self.opt_D = torch.optim.Adam(self.D.parameters(), lr=self.config["lr"], betas=(0.5, 0.999))
            self.opt_G = torch.optim.Adam(self.G.parameters(), lr=self.config["lr"], betas=(0.5, 0.999))
        elif self.model_type == 'WGAN':
            self.opt_D = torch.optim.RMSprop(self.D.parameters(), lr=self.config["lr"])
            self.opt_G = torch.optim.RMSprop(self.G.parameters(), lr=self.config["lr"])

        self.dataloader = None
        self.log_dir = os.path.join(self.config["save_dir"], 'logs')
        self.ckpt_dir = os.path.join(self.config["save_dir"], 'checkpoints')

        FORMAT = '%(asctime)s - %(levelname)s: %(message)s'
        logging.basicConfig(level=logging.INFO, format=FORMAT, datefmt='%Y-%m-%d %H:%M')

        self.steps = 0
        # fixed noise for 100 samples, rendered after every epoch to watch progress
        self.z_samples = torch.randn(100, self.config["z_dim"]).to(self.device)

    def prepare_environment(self):
        """
        Use this function to prepare the environment
        """
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)

        # update dir by time
        time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        self.log_dir = os.path.join(self.log_dir, time + f'_{self.config["model_type"]}')
        self.ckpt_dir = os.path.join(self.ckpt_dir, time + f'_{self.config["model_type"]}')
        os.makedirs(self.log_dir)
        os.makedirs(self.ckpt_dir)

        # model preparation
        self.G = self.G.to(self.device)
        self.D = self.D.to(self.device)
        self.G.train()
        self.D.train()

    def gp(self, r_imgs, f_imgs):
        """
        Implement gradient penalty function
        """
        # one interpolation coefficient per sample, on the same device as the images
        alpha = torch.rand(r_imgs.size(0), 1, 1, 1, device=r_imgs.device)
        interpolates = (alpha * r_imgs + (1 - alpha) * f_imgs).requires_grad_(True)
        d_interpolates = self.D(interpolates)
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,   # the penalty itself must be differentiable
            retain_graph=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        # WGAN-GP penalizes the L2 norm of the critic's gradient
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def train(self, dataloader):
        """
        Use this function to train generator and discriminator
        """
        self.prepare_environment()

        for e, epoch in enumerate(range(self.config["n_epoch"])):
            progress_bar = tqdm(dataloader)
            progress_bar.set_description(f"Epoch {e+1}")
            for i, data in enumerate(progress_bar):
                bs = data.size(0)  # batch size

                # *********************
                # *      Train D      *
                # *********************
                z = torch.randn(bs, self.config["z_dim"]).to(self.device)  # z could even be sampled once before training and reused
                f_imgs = self.G(z)
                r_imgs = data.to(self.device)

                # Discriminator forwarding
                r_logit = self.D(r_imgs)           # scores for real images
                f_logit = self.D(f_imgs.detach())  # scores for fake images; detach() keeps gradients out of G

                # SETTING DISCRIMINATOR LOSS
                if self.model_type == 'GAN':
                    r_label = torch.ones((bs)).to(self.device)
                    f_label = torch.zeros((bs)).to(self.device)
                    r_loss = self.loss(r_logit, r_label)
                    f_loss = self.loss(f_logit, f_label)
                    loss_D = (r_loss + f_loss) / 2
                elif self.model_type == 'WGAN':
                    loss_D = -torch.mean(r_logit) + torch.mean(f_logit)
                elif self.model_type == 'WGAN-GP':
                    aa = -torch.mean(r_logit) + torch.mean(f_logit)
                    bb = self.gp(r_imgs, f_imgs.detach())
                    loss_D = aa + bb  # the last term is the gradient penalty

                # Discriminator backwarding
                self.D.zero_grad()
                loss_D.backward()
                self.opt_D.step()

                # SETTING WEIGHT CLIP:
                if self.model_type == 'WGAN':
                    for p in self.D.parameters():
                        p.data.clamp_(-self.config["clip_value"], self.config["clip_value"])

                # *********************
                # *      Train G      *
                # *********************
                if self.steps % self.config["n_critic"] == 0:
                    # Generator forwarding: score f_imgs again with the updated D (no need to regenerate them)
                    f_logit = self.D(f_imgs)
                    if self.model_type == 'GAN':
                        loss_G = self.loss(f_logit, r_label)
                    elif self.model_type == 'WGAN' or self.model_type == 'WGAN-GP':
                        loss_G = -torch.mean(f_logit)

                    # Generator backwarding
                    self.G.zero_grad()
                    loss_G.backward()
                    self.opt_G.step()

                if self.steps % 10 == 0:
                    progress_bar.set_postfix(loss_G=loss_G.item(), loss_D=loss_D.item())
                    if self.model_type == 'WGAN-GP':
                        print(aa.item(), bb.item())  # Wasserstein term and gradient penalty

                self.steps += 1

            self.G.eval()
            with torch.no_grad():
                # G() ends with tanh(), so its output is in [-1, 1]; map it to [0, 1] to get an image
                f_imgs_sample = (self.G(self.z_samples) + 1) / 2.0
            filename = os.path.join(self.log_dir, f'Epoch_{epoch+1:03d}.jpg')
            torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
            logging.info(f'Save some samples to {filename}.')

            # Show some images during training.
            grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid_img.permute(1, 2, 0))
            plt.show()

            self.G.train()

            if (e + 1) % 5 == 0 or e == 0:
                # Save the checkpoints.
                torch.save(self.G.state_dict(), os.path.join(self.ckpt_dir, f'G_{e}.pth'))
                torch.save(self.D.state_dict(), os.path.join(self.ckpt_dir, f'D_{e}.pth'))

        logging.info('Finish training')

    def inference(self, G_path, n_generate=1000, n_output=30, show=False):
        """
        1. G_path is the path for Generator ckpt
        2. You can use this function to generate final answer
        """
        self.G.load_state_dict(torch.load(G_path, map_location=self.device))
        self.G.to(self.device)
        self.G.eval()
        z = torch.randn(n_generate, self.config["z_dim"]).to(self.device)
        with torch.no_grad():
            imgs = (self.G(z) + 1) / 2.0

        os.makedirs('output', exist_ok=True)
        for i in range(n_generate):
            torchvision.utils.save_image(imgs[i], f'output/{i+1}.jpg')

        if show:
            row, col = n_output // 10 + 1, 10
            grid_img = torchvision.utils.make_grid(imgs[:n_output].cpu(), nrow=row)
            plt.figure(figsize=(row, col))
            plt.imshow(grid_img.permute(1, 2, 0))
            plt.show()
```
Training
Loading the data
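With drop_last=True the DataLoader discards the last incomplete batch. As a result, the float printed by the loading code (len(dataset) / batch_size) is not the number of iterations per epoch. The arithmetic below assumes the faces folder holds all 71,314 images. `batches_per_epoch` is an illustrative helper.

```python
def batches_per_epoch(n_samples, batch_size, drop_last):
    full, rest = divmod(n_samples, batch_size)
    # drop_last=True discards the final partial batch; otherwise it becomes one extra batch
    return full if drop_last or rest == 0 else full + 1

n_images, batch_size = 71314, 512
print(batches_per_epoch(n_images, batch_size, drop_last=True))   # 139
print(batches_per_epoch(n_images, batch_size, drop_last=False))  # 140
print(n_images % batch_size)  # 146 images left out of each epoch when drop_last=True
```

Because shuffle=True, the 146 left-out images differ from epoch to epoch.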
```python
# create dataset by the above function
batch_size = 512
num_workers = 2
dataset = get_dataset(os.path.join(workspace_dir, 'faces'))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, drop_last=True)
print('Training set size: {:d}, number of batches: {:.2f}'.format(len(dataset), len(dataset) / batch_size))
```
Set config
```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'DEVICE: {device}')

config = {
    "model_type": "WGAN",
    "lr": 1e-4,
    "n_epoch": 60,
    "n_critic": 5,  # train D several times per G update, which tends to work better; 5 means a G:D ratio of 1:5
    "z_dim": 100,
    "workspace_dir": workspace_dir,  # define in the environment setting
    "save_dir": workspace_dir,
    "clip_value": 1,
}

trainer = TrainerGAN(config, device)
trainer.train(dataloader)
```
Inference
```python
# save the 1000 images into ./output folder
trainer.inference(f'{workspace_dir}/checkpoints/2022-03-31_15-59-17_GAN/G_0.pth')  # you have to modify the path when running this line
```
GAN results
Below are the images generated by the GAN. The quality is fairly mediocre. I only did a rough run, and more tuning should improve it considerably.
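Besides the poor quality, something else showed up during training. At epoch 22 the images suddenly got worse. One epoch they were still normal faces (the paused frame in the gif below, with a red-haired face in the top-left corner), and the next epoch they suddenly fell apart. According to "Hung-yi Lee 2022 Machine Learning HW6 walkthrough" on 机器学习手艺人's CSDN blog, at this point loss_G jumps up and loss_D approaches 0. This means the discriminator has become far stronger than the generator, which runs against the goal of GAN training. The desired outcome is a discriminator that can no longer tell the generator's images from real ones. At the theoretical optimum D outputs 1/2 everywhere, so both BCE losses sit near ln 2 ≈ 0.69 instead of dropping to 0.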
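Another problem is that later in training the generated images become less diverse (mode collapse). Prof. Lee explained the reasons in class.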
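Below are the images generated by WGAN. Training stayed fairly stable all the way to epoch 50.
I noticed something interesting about computation speed. With the same hyperparameters: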
```python
config = {
    "model_type": "GAN",
    "batch_size": 64,
    "lr": 1e-4,
    "n_epoch": 10,
    "n_critic": 1,
    "z_dim": 100,
    "workspace_dir": workspace_dir,
}
```
training took 428 seconds on an NVIDIA RTX 3090 but only 327 seconds on the RTX 3080, which was actually faster. I don't know why.
Theory references:
- Hung-yi Lee Machine Learning: Generative Adversarial Networks (GAN), iwill323's CSDN blog
- Understanding the basic principles of GAN networks, ifreewolf99's CSDN blog
Code references:
- Understanding GAN and DCGAN (PyTorch + Prof. Hung-yi Lee's HW6), 富士山上, cnblogs
- Hung-yi Lee 2022 Machine Learning HW6 walkthrough, 机器学习手艺人's CSDN blog