基于GAN的动漫头像生成
GAN的原理
GAN是一種典型的生成網絡模型,它類似于編解碼結構,通過訓練,他能夠生成不同于訓練集的各種圖片。
首先先訓練判別器,把真圖通過判別器的輸出和真標簽作損失,把假圖通過判別器的輸出和假標簽作損失,讓它具備判別真圖和假圖的能力。然后再訓練生成器,把生成器生成的假圖通過判別器的輸出和真標簽作損失。經過反復的訓練,讓判別器難以分辨生成圖的真假,也就是讓它判別為真或為假的概率各為0.5
數據集下載
網上下載的動漫頭像數據集有很多不清晰的奇異樣本,對此我做了清洗,剩下的都是符合標準的,可直接下載
百度網盤:https://pan.baidu.com/s/1–zFrJdg1gtW2wJ6wtWQsQ
密碼:bu55
網絡結構
生成網絡
相當于一個編碼器
class NetD(nn.Module):# 構建一個判別器,相當與一個二分類問題, 生成一個值def __init__(self):super(NetD, self).__init__()ndf = opt.ndfself.main = nn.Sequential(# 輸入96*96*3nn.Conv2d(3, ndf, 5, 3, 1, bias=False),nn.LeakyReLU(negative_slope=0.2, inplace=True),# 輸入32*32*ndfnn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, True),# 輸入16*16*ndf*2nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, True),# 輸入為8*8*ndf*4nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, True),# 輸入為4*4*ndf*8nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=True),nn.Sigmoid() # 分類問題)def forward(self, x):return self.main(x).view(-1)生成器
相當于一個解碼器
class NetG(nn.Module):# 定義一個生成模型,通過輸入噪聲來產生一張圖片def __init__(self):super(NetG, self).__init__()ngf = opt.ngfself.main = nn.Sequential(# 假定輸入為一張1*1*opt.nz維的數據(opt.nz維的向量)nn.ConvTranspose2d(opt.nz , ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(inplace=True),# 輸入一個4*4*ngf*8nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# 輸入一個8*8*ngf*4nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=True),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# 輸入一個16*16*ngf*2nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(inplace=True),# 輸入一個32*32*ngfnn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),nn.Tanh()# 輸出一張96*96*3)def forward(self, x):return self.main(x)GAN網絡結構設計要點
1、在D網絡中用stride卷積(stride>1)代替pooling層,在G網絡中用conv2d_transpose代替上采樣層
2、在G和D網絡中直接將BN應用到所有層會導致樣本震蕩和模型不穩定,通過在G網絡輸出層和D網絡輸入層不采用BN層可以有效防止這種現象
3、不使用全連接層作為輸出
4、G網絡中除了輸出層用tanh激活,其他層都是用ReLu激活
5、D網絡中都使用LeakyReLu激活
網絡模型訓練
訓練細節
1、預處理環節,將圖像scale到tanh的[-1,1]
2、所有的參數初始化由(0,0.02)的正態分布中隨機得到
3、LeakyReLu的斜率是0.2(默認)
4、優化器Adam的learning rate=0.0002,momentum參數betas的beta1從0.9降為0.5,beta2默認,防止震蕩和不穩定
5、可以G網絡訓練1次,然后D網絡訓練1次,如此反復;也可以G網絡先訓練幾次后,D網絡再訓練1次,如此反復。前者效果出得較快,后者較慢。
訓練代碼
效果展示
生成網絡隨機生成的頭像
總結
以上是生活随笔為你收集整理的基于GAN的动漫头像生成的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 破解握手包
- 下一篇: ssh 工具 socket 10106