日韩性视频-久久久蜜桃-www中文字幕-在线中文字幕av-亚洲欧美一区二区三区四区-撸久久-香蕉视频一区-久久无码精品丰满人妻-国产高潮av-激情福利社-日韩av网址大全-国产精品久久999-日本五十路在线-性欧美在线-久久99精品波多结衣一区-男女午夜免费视频-黑人极品ⅴideos精品欧美棵-人人妻人人澡人人爽精品欧美一区-日韩一区在线看-欧美a级在线免费观看

歡迎訪問(wèn) 生活随笔!

生活随笔

當(dāng)前位置: 首頁(yè) > 编程资源 > 编程问答 >内容正文

编程问答

用MXNet实现mnist的生成对抗网络(GAN)

發(fā)布時(shí)間:2024/7/19 编程问答 49 豆豆
生活随笔 收集整理的這篇文章主要介紹了 用MXNet实现mnist的生成对抗网络(GAN) 小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.

用MXNet實(shí)現(xiàn)mnist的生成對(duì)抗網(wǎng)絡(luò)(GAN)

生成式對(duì)抗網(wǎng)絡(luò)(Generative Adversarial Network,簡(jiǎn)稱GAN)由一個(gè)生成網(wǎng)絡(luò)與一個(gè)判別網(wǎng)絡(luò)組成。生成網(wǎng)絡(luò)從潛在空間(latent space)中隨機(jī)采樣作為輸入,其輸出結(jié)果需要盡量模仿訓(xùn)練集中的真實(shí)樣本。判別網(wǎng)絡(luò)的輸入則為真實(shí)樣本或生成網(wǎng)絡(luò)的輸出,其目的是將生成網(wǎng)絡(luò)的輸出從真實(shí)樣本中盡可能分辨出來(lái)。而生成網(wǎng)絡(luò)則要盡可能地欺騙判別網(wǎng)絡(luò)。兩個(gè)網(wǎng)絡(luò)相互對(duì)抗、不斷調(diào)整參數(shù),最終目的是使判別網(wǎng)絡(luò)無(wú)法判斷生成網(wǎng)絡(luò)的輸出結(jié)果是否真實(shí)。從數(shù)據(jù)的分布來(lái)看就是使得生成的數(shù)據(jù)分布\(P_z(z)\)與原來(lái)的數(shù)據(jù)\(P_{data}(x)\)十分接近,理想的情況下為\(P_z(z)=P_{data}(x)\)。本文給出了GAN的Loss函數(shù)、說(shuō)明GAN的訓(xùn)練原理,再結(jié)合最簡(jiǎn)單的例子mnist,用MXNet來(lái)實(shí)現(xiàn)GAN。

GAN的基本概念

在一樣樣本中加入一些精心編制的噪聲,會(huì)使得原來(lái)的分類器失效。圖1是一個(gè)廣為流傳的示例,左邊的分類器得到的是熊貓而右邊被分類為了長(zhǎng)臂猿。

圖1 誤分類的示例

為什么會(huì)有這樣的結(jié)果?圖像分類器本質(zhì)上是多維空間中的決策邊界,當(dāng)訓(xùn)練的樣本不足時(shí),可能會(huì)使得分類器過(guò)擬合。當(dāng)向原樣本中加入一些L2范數(shù)很小的噪聲時(shí),人類的視覺(jué)是無(wú)法分別這些細(xì)微的差別,所以依然會(huì)認(rèn)為和原樣本的分類沒(méi)什么區(qū)別。但對(duì)過(guò)擬合的分類器來(lái)說(shuō),輸入樣本的小偏差可能使得最后的決策點(diǎn)越過(guò)了原來(lái)的決策邊界,進(jìn)入到其它分類中了。這就導(dǎo)致了錯(cuò)誤的分類。

對(duì)于生成網(wǎng)絡(luò)設(shè)為G,\(G(Z)\)為生成的對(duì)抗樣本,理想條件下\(G(z)\)隨機(jī)生成的樣本分布與真實(shí)樣本分布是一樣。對(duì)于判別網(wǎng)絡(luò)設(shè)為D,\(D(x)\)為判別樣本是真實(shí)的概率,理想條件下對(duì)真實(shí)樣本有\(G(x)=1\),對(duì)生成樣本有\(D(G(z))=0\)。為了達(dá)到效果,設(shè)計(jì)了如圖2所示的網(wǎng)絡(luò)結(jié)構(gòu):

圖2 GAN的網(wǎng)絡(luò)結(jié)構(gòu)

Loss函數(shù)如下:

\[ V(G,D)=E_{x-p_{data}(x)}[\log(D(x))] + E_{z-p_{z}(z)}[1-\log(D(G(z)))] \tag{1.1} \]

這個(gè)Loss函數(shù)的優(yōu)化方法與EM算法的思想是相似的:在G是固定的情況下,判別網(wǎng)絡(luò)D的精確率越高,那么V就越大;在D固定的條件下,生成網(wǎng)絡(luò)G的生成的樣本越像實(shí)際樣本,那么V就越小。所有V(G,D)進(jìn)行了極小極大化博弈:

\[ \min_G \max_D V(G,D)=E_{x-p_{data}(x)}[\log(D(x))] + E_{z-p_{z}(z)}[1-\log(D(G(z)))] \tag{1.2} \]

實(shí)現(xiàn)mnist的GAN

MXNet的源碼給出了mnsit的GAN實(shí)現(xiàn)(見(jiàn)dcgan.py),但是沒(méi)有給出詳細(xì)的說(shuō)明,我在這里詳細(xì)解釋下,源文件在裝了相關(guān)的python包之后是能正確運(yùn)行的。DCGAN是指Deep Convolution Generative Adversarial Netword(深度卷積生成式對(duì)抗網(wǎng)格)。

mnist的網(wǎng)絡(luò)相對(duì)來(lái)說(shuō)比較簡(jiǎn)單,如圖所示:

圖3 D是判別式網(wǎng)絡(luò),G是生成式網(wǎng)絡(luò),可以看到兩個(gè)網(wǎng)絡(luò)輸出的數(shù)據(jù)大致成反向?qū)ΨQ

生成網(wǎng)絡(luò)G的結(jié)構(gòu)與判別網(wǎng)絡(luò)D的結(jié)果是反向?qū)ΨQ的(雖然兩個(gè)網(wǎng)絡(luò)的開頭或者結(jié)尾有所不同,但這是為了與結(jié)果相對(duì)應(yīng)),這里有一個(gè)很重要但被很多文章忽略的假設(shè):判別網(wǎng)絡(luò)從潛在空間(latent space)是可逆的。不是說(shuō)從最后的結(jié)果是可逆的,但從原始圖片映射到潛在空間這個(gè)過(guò)程(比如說(shuō)從全連接層的n(n一般比較大)維向量)是可逆的,這里說(shuō)的可逆不是嚴(yán)格意義上的反函數(shù),而是從視覺(jué)判別結(jié)果上區(qū)別不大,比如說(shuō)在G與D理想的情況下數(shù)字9通過(guò)判別網(wǎng)絡(luò)得到一個(gè)100維的向量,再將這個(gè)100維向量通過(guò)生成網(wǎng)絡(luò)G得到一張圖片,這張圖片在人類看來(lái)也是9。

代碼實(shí)現(xiàn)如下:

def make_dcgan_sym(ngf, ndf, nc, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12):BatchNorm = mx.sym.BatchNorm# 生成網(wǎng)絡(luò)G# 輸入生成網(wǎng)絡(luò)G的變量,這個(gè)是潛在空間rand = mx.sym.Variable('rand')g1 = mx.sym.Deconvolution(rand, name='g1', kernel=(4,4), num_filter=ngf*8, no_bias=no_bias)gbn1 = BatchNorm(g1, name='gbn1', fix_gamma=fix_gamma, eps=eps)gact1 = mx.sym.Activation(gbn1, name='gact1', act_type='relu')g2 = mx.sym.Deconvolution(gact1, name='g2', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ngf*4, no_bias=no_bias)gbn2 = BatchNorm(g2, name='gbn2', fix_gamma=fix_gamma, eps=eps)gact2 = mx.sym.Activation(gbn2, name='gact2', act_type='relu')g3 = mx.sym.Deconvolution(gact2, name='g3', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ngf*2, no_bias=no_bias)gbn3 = BatchNorm(g3, name='gbn3', fix_gamma=fix_gamma, eps=eps)gact3 = mx.sym.Activation(gbn3, name='gact3', act_type='relu')g4 = mx.sym.Deconvolution(gact3, name='g4', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ngf, no_bias=no_bias)gbn4 = BatchNorm(g4, name='gbn4', fix_gamma=fix_gamma, eps=eps)gact4 = mx.sym.Activation(gbn4, name='gact4', act_type='relu')g5 = mx.sym.Deconvolution(gact4, name='g5', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=nc, no_bias=no_bias)# 生成網(wǎng)絡(luò)G最后得到一張相片gout = mx.sym.Activation(g5, name='gact5', act_type='tanh')# 判別網(wǎng)絡(luò)D,這里里的結(jié)構(gòu)與一般的分類網(wǎng)絡(luò)區(qū)別不大data = mx.sym.Variable('data')label = mx.sym.Variable('label')d1 = mx.sym.Convolution(data, name='d1', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf, no_bias=no_bias)dact1 = mx.sym.LeakyReLU(d1, name='dact1', act_type='leaky', slope=0.2)d2 = mx.sym.Convolution(dact1, name='d2', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf*2, no_bias=no_bias)dbn2 = BatchNorm(d2, name='dbn2', fix_gamma=fix_gamma, eps=eps)dact2 = mx.sym.LeakyReLU(dbn2, name='dact2', act_type='leaky', slope=0.2)d3 = mx.sym.Convolution(dact2, name='d3', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf*4, no_bias=no_bias)dbn3 = BatchNorm(d3, name='dbn3', fix_gamma=fix_gamma, eps=eps)dact3 = mx.sym.LeakyReLU(dbn3, name='dact3', act_type='leaky', slope=0.2)d4 = mx.sym.Convolution(dact3, name='d4', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf*8, no_bias=no_bias)dbn4 = BatchNorm(d4, name='dbn4', fix_gamma=fix_gamma, eps=eps)dact4 = mx.sym.LeakyReLU(dbn4, name='dact4', act_type='leaky', slope=0.2)d5 = mx.sym.Convolution(dact4, name='d5', kernel=(4,4), num_filter=1, no_bias=no_bias)d5 = mx.sym.Flatten(d5)# 用邏輯回歸計(jì)算最后的lossdloss = mx.sym.LogisticRegressionOutput(data=d5, label=label, name='dloss')# 返回這G與D這兩個(gè)網(wǎng)絡(luò)return gout, dloss

在訓(xùn)練的過(guò)程中,所有的原樣本的label為1,生成網(wǎng)絡(luò)G生成的樣本的label為0,用這樣來(lái)區(qū)別原樣本與生成的對(duì)抗樣本。生成網(wǎng)絡(luò)輸入的潛在空間樣本是100維的,訓(xùn)練過(guò)程如下:

  • 用生成網(wǎng)絡(luò)G生成對(duì)抗樣本gout
  • 對(duì)抗樣本的label設(shè)為0,因?yàn)橐扔眠@個(gè)訓(xùn)練判別網(wǎng)絡(luò)D
  • 用gout來(lái)訓(xùn)練判別網(wǎng)絡(luò)D,得到梯度,但不更新
  • 對(duì)原樣本的label設(shè)為1,再用之來(lái)訓(xùn)練判別網(wǎng)絡(luò)D
  • 得到梯度后合入gout得到的梯度,更新D的參數(shù)
  • 下面的過(guò)程是為了得到生成網(wǎng)絡(luò)G的loss
    • 設(shè)gout的label為1,因?yàn)樯删W(wǎng)絡(luò)G的目標(biāo)就是要生成label為1的樣本,所以訓(xùn)練G的label為1。反之,如果訓(xùn)練D,為了區(qū)別原樣本與生成樣本所以label為0。
    • 用判別網(wǎng)絡(luò)D來(lái)得輸入的梯度dgout,這個(gè)梯度就是生成網(wǎng)絡(luò)G的loss。
  • 用這個(gè)loss反向傳播生成網(wǎng)絡(luò)G,并更新參數(shù)。

這里面的關(guān)鍵就是用判別網(wǎng)絡(luò)D來(lái)得到生成網(wǎng)絡(luò)G的loss,之所以可以這樣,是因?yàn)檫@兩個(gè)網(wǎng)絡(luò)是可逆的。訓(xùn)練的代碼如下:

if __name__ == '__main__':logging.basicConfig(level=logging.DEBUG)# =============setting============dataset = 'mnist'imgnet_path = './train.rec'ndf = 64ngf = 64nc = 3batch_size = 64Z = 100lr = 0.0002beta1 = 0.5ctx = mx.gpu(0)check_point = FalsesymG, symD = make_dcgan_sym(ngf, ndf, nc)#mx.viz.plot_network(symG, shape={'rand': (batch_size, 100, 1, 1)}).view()#mx.viz.plot_network(symD, shape={'data': (batch_size, nc, 64, 64)}).view()# ==============data==============if dataset == 'mnist':X_train, X_test = get_mnist()train_iter = mx.io.NDArrayIter(X_train, batch_size=batch_size)elif dataset == 'imagenet':train_iter = ImagenetIter(imgnet_path, batch_size, (3, 64, 64))rand_iter = RandIter(batch_size, Z)label = mx.nd.zeros((batch_size,), ctx=ctx)# =============module G=============modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctx)modG.bind(data_shapes=rand_iter.provide_data)modG.init_params(initializer=mx.init.Normal(0.02))modG.init_optimizer(optimizer='adam',optimizer_params={'learning_rate': lr,'wd': 0.,'beta1': beta1,})mods = [modG]# =============module D=============modD = mx.mod.Module(symbol=symD, data_names=('data',), label_names=('label',), context=ctx)modD.bind(data_shapes=train_iter.provide_data,label_shapes=[('label', (batch_size,))],inputs_need_grad=True)modD.init_params(initializer=mx.init.Normal(0.02))modD.init_optimizer(optimizer='adam',optimizer_params={'learning_rate': lr,'wd': 0.,'beta1': beta1,})mods.append(modD)# ============printing==============def norm_stat(d):return mx.nd.norm(d)/np.sqrt(d.size)mon = mx.mon.Monitor(10, norm_stat, pattern=".*output|d1_backward_data", sort=True)mon = Noneif mon is not None:for mod in mods:passdef facc(label, pred):pred = pred.ravel()label = label.ravel()return ((pred > 0.5) == label).mean()def fentropy(label, pred):pred = pred.ravel()label = label.ravel()return -(label*np.log(pred+1e-12) + (1.-label)*np.log(1.-pred+1e-12)).mean()mG = mx.metric.CustomMetric(fentropy)mD = mx.metric.CustomMetric(fentropy)mACC = mx.metric.CustomMetric(facc)print('Training...')stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')# =============train===============for epoch in range(100):train_iter.reset()for t, batch in enumerate(train_iter):rbatch = rand_iter.next()if mon is not None:mon.tic()# 首先生成對(duì)抗樣本modG.forward(rbatch, is_train=True)outG = modG.get_outputs()# update discriminator on fake# 這里的負(fù)樣本label為0,正樣本label為1,不像普遍的mnist一樣。那么modG就想生成樣本label為1的,modD要將modG生成的數(shù)據(jù)判定為0# train_iter(真實(shí)樣本)中的數(shù)據(jù)判定為1。label[:] = 0modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)modD.backward()#modD.update()# 先Copy得到的對(duì)抗樣本的梯度,要注意是復(fù)制不是引用。gradD = [[grad.copyto(grad.context) for grad in grads] for grads in modD._exec_group.grad_arrays]modD.update_metric(mD, [label])modD.update_metric(mACC, [label])# update discriminator on real# 對(duì)真實(shí)樣本的數(shù)據(jù)訓(xùn)練label[:] = 1batch.label = [label]modD.forward(batch, is_train=True)modD.backward()# 對(duì)抗樣本與真實(shí)樣本的梯度合到一起建行梯度更新for gradsr, gradsf in zip(modD._exec_group.grad_arrays, gradD):for gradr, gradf in zip(gradsr, gradsf):gradr += gradfmodD.update()modD.update_metric(mD, [label])modD.update_metric(mACC, [label])# update generator# 更新modG的參數(shù),這里要注意的是,modG想要生成的樣本label是1的,所以在modD中用了這個(gè)label,就是想生成的樣本向label=1靠近。# 前向和向后生成輸入數(shù)據(jù)的梯度diffDlabel[:] = 1modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)modD.backward()diffD = modD.get_input_grads()# diffD就是modG的loss產(chǎn)生的梯度,用它來(lái)向后傳播并更新參數(shù)。modG.backward(diffD)modG.update()mG.update([label], modD.get_outputs())if mon is not None:mon.toc_print()t += 1if t % 10 == 0:print('epoch:', epoch, 'iter:', t, 'metric:', mACC.get(), mG.get(), mD.get())mACC.reset()mG.reset()mD.reset()visual('gout', outG[0].asnumpy())diff = diffD[0].asnumpy()diff = (diff - diff.mean())/diff.std()visual('diff', diff)visual('data', batch.data[0].asnumpy())if check_point:print('Saving...')modG.save_params('%s_G_%s-%04d.params'%(dataset, stamp, epoch))modD.save_params('%s_D_%s-%04d.params'%(dataset, stamp, epoch))

訓(xùn)練的結(jié)果部分結(jié)果如下,gout是生成的樣本,data是原樣本,diff是它們的差。可以從后面生成的gout中看到,結(jié)果缺少一些數(shù)字,比如2、3等,這是因?yàn)槲覀儧](méi)有對(duì)各個(gè)數(shù)字的潛在空間進(jìn)行生成樣本而是用統(tǒng)一的空間,這個(gè)統(tǒng)一的空間中對(duì)應(yīng)的數(shù)字可能沒(méi)有2、3等或者說(shuō)它們點(diǎn)的比例相對(duì)來(lái)說(shuō)比較小,樣例用到的空間只是保證生成樣本是數(shù)字,但并不保證每個(gè)數(shù)字都會(huì)有,如果我保證生成每個(gè)數(shù)字的樣本,那么得重新設(shè)計(jì)程序,但原理和例程相差不大。

圖4 輸出的圖像結(jié)果:data是原始數(shù)據(jù),gout是G生成的對(duì)搞樣本,diff是兩者的差。

過(guò)程打印的輸出如下:

epoch: 99 iter: 930 metric: ('facc', 1.0) ('fentropy', 8.3449375152587884) ('fentropy', 0.00077932097192388026)

【防止爬蟲轉(zhuǎn)載而導(dǎo)致的格式問(wèn)題——鏈接】:
http://www.cnblogs.com/heguanyou/p/7642608.html

轉(zhuǎn)載于:https://www.cnblogs.com/heguanyou/p/7642608.html

總結(jié)

以上是生活随笔為你收集整理的用MXNet实现mnist的生成对抗网络(GAN)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。

如果覺(jué)得生活随笔網(wǎng)站內(nèi)容還不錯(cuò),歡迎將生活随笔推薦給好友。