【GANs入门】pytorch-GANs任务迁移-单个目标(数字的生成)
簡述
之前認真學習了網(wǎng)上的一份,代碼做了很詳細的筆記。
【Gans入門】Pytorch實現(xiàn)Gans代碼詳解【70+代碼】
但是上面的任務只是畫一條在一定區(qū)間下的曲線。
這里對這個進行遷移,到可以進行圖像的生成。
圖像的很多數(shù)據(jù)都沒有,但是突然想到在sklearn上的digits是一個非常簡單的圖片。
這里我想到之前的一份筆記
sklearn學習(一)
這里會使用sklearn自帶的小數(shù)據(jù)來做訓練
目標是讓神經(jīng)網(wǎng)絡自己學會生成數(shù)字。
任務描述
為了讓神經(jīng)網(wǎng)絡操作更簡單。這里的輸入數(shù)據(jù)只會選擇特定數(shù)值的數(shù)字圖片數(shù)據(jù)。然后丟給對抗生成神經(jīng)網(wǎng)絡學習。讓其中的生成器學會如何生成手寫數(shù)字。
下面是選擇用數(shù)值1的生成過程
其實可以發(fā)現(xiàn)其實是有點這樣的感覺了。
下面的這個是讓它學習數(shù)字0的效果
可能是由于數(shù)字0的細節(jié)更粗糙一點,所以,可以發(fā)現(xiàn),我們認為這個0生成的更好。(數(shù)字1和數(shù)字4其實是有點像的,所以會有點問題,還有這是因為圖片像素有點低)
代碼詳解
導入包
- torch,numpy這些都是數(shù)據(jù)處理過程中需要的包
- matplotlib為了畫圖
- sklearn主要是為了它本身帶的數(shù)據(jù)
- random主要是為了選擇標準數(shù)據(jù)更具有隨機性
- os,shutil,imageio這三個庫是為了畫出gif動態(tài)圖
創(chuàng)建臨時文件夾
PNGFILE = './png/' if not os.path.exists(PNGFILE):os.mkdir(PNGFILE) else:shutil.rmtree(PNGFILE)os.mkdir(PNGFILE)這里會創(chuàng)建一個臨時的文件夾png,會把中途生成的那些圖片都存在這,然后我就可以用這些png來生成gif文件
模型參數(shù)
- BATCH_SIZE這個參數(shù)表示每次用多少的數(shù)據(jù)來進行考量。(數(shù)值多的話模型進化的會稍微快點)
- LR_G跟LR_D表示兩個模型的學習率
- N_IDEAS:啟發(fā)式因子(生成函數(shù)的初始層的節(jié)點數(shù))。因為我們要操作的節(jié)點數(shù)量會特別大(特別是圖像問題,但是如果輸入節(jié)點過于大的話,會需要大量的計算資源。所以用小一點的這個基本夠用就行了)
- target_num :表示的是想要生成的數(shù)字。由于數(shù)據(jù)集中只有(0到9)所以,這里也只能取0到9。
- image_max表示圖片像素點的最大值,這個一開始我用到了,但是后來我修改了代碼之后,就用不到了。
- ART_COMPONENTS:像素點數(shù)量(其實本質(zhì)上跟前一個版本的參考節(jié)點數(shù)都是一樣的)
標準數(shù)據(jù)
這個函數(shù)本質(zhì)上,這個區(qū)間上選BATCH_SIZE個標準數(shù)據(jù)。
但是,random.sample只能輸入的是list所以需要先把data轉(zhuǎn)成list,但是轉(zhuǎn)出來的list又不能直接變成torch中的Tensor,這里需要再轉(zhuǎn)成ndarray,之后再轉(zhuǎn)成Tensor,但是要注意在后面加一個.float()函數(shù)的操作。
構建模型
生成器模型,但是Linear轉(zhuǎn)成的數(shù)據(jù)是有可能有負數(shù)的數(shù)據(jù)的,但是作為圖片肯定是不可以有這樣的數(shù)據(jù)的。因為數(shù)據(jù)一定是需要為大于等于0的數(shù)據(jù)。
所以搭建的這個模型最后一定要加一個ReLU()這樣的類似的,來保證沒有0的情況。
G = nn.Sequential( # Generatornn.Linear(N_IDEAS, 128), # random ideas (could from normal distribution)nn.ReLU(),nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideasnn.ReLU(), )D = nn.Sequential( # Discriminatornn.Linear(ART_COMPONENTS, 128), # receive art work either from the famous artist or a newbie like Gnn.ReLU(),nn.Linear(128, 1),nn.Sigmoid(), # tell the probability that the art work is made by artist )構建最優(yōu)化的模型
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D) opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)迭代優(yōu)化
這跟之前的是類似的。
for step in range(10000):artist_paintings = artist_works() # real painting from artistG_ideas = torch.randn(BATCH_SIZE, N_IDEAS) # random ideasG_paintings = G(G_ideas) # fake painting from G (random ideas)prob_artist0 = D(artist_paintings) # D try to increase this probprob_artist1 = D(G_paintings) # D try to reduce this probD_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))G_loss = torch.mean(torch.log(1. - prob_artist1))opt_D.zero_grad()D_loss.backward(retain_graph=True) # reusing computational graphopt_D.step()opt_G.zero_grad()G_loss.backward(retain_graph=True)opt_G.step()畫圖并保存
if step % 100 == 0: # plottingplt.cla()tempdata = G_paintings[0].detach().numpy()tempdata = tempdata.reshape((8, 8))plt.imshow(tempdata, cmap=plt.cm.gray_r)# plt.draw()plt.savefig(PNGFILE + '%d.png' % times)filedatalist.append(PNGFILE + '%d.png' % times)times += 1plt.pause(0.01)生成gif
generated_images = [] for png_path in filedatalist:generated_images.append(imageio.imread(png_path)) shutil.rmtree(PNGFILE) imageio.mimsave('gan.gif', generated_images, 'GIF', duration=0.1)全部代碼
import torch import torch.nn as nn import numpy as np import matplotlib.pyplot as plt from sklearn import datasets import random import os import shutil import imageioPNGFILE = './png/' if not os.path.exists(PNGFILE):os.mkdir(PNGFILE) else:shutil.rmtree(PNGFILE)os.mkdir(PNGFILE)# Hyper Parameters BATCH_SIZE = 64 LR_G = 0.00001 # learning rate for generator LR_D = 0.00001 # learning rate for discriminator N_IDEAS = 6 # think of this as number of ideas for generating an art work (Generator) target_num = 0 # target Numberdigits = datasets.load_digits() target = digits.target data = digits.data[target == target_num] image_max = max(data.reshape((-1,))) ART_COMPONENTS = data.shape[-1] # it could be total point G can draw in the canvasdef artist_works(): # painting from the famous artist (real target)return torch.from_numpy(np.array(random.sample(list(data), BATCH_SIZE))).float()G = nn.Sequential( # Generatornn.Linear(N_IDEAS, 128), # random ideas (could from normal distribution)nn.ReLU(),nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideasnn.ReLU(), )D = nn.Sequential( # Discriminatornn.Linear(ART_COMPONENTS, 128), # receive art work either from the famous artist or a newbie like Gnn.ReLU(),nn.Linear(128, 1),nn.Sigmoid(), # tell the probability that the art work is made by artist )opt_D = torch.optim.Adam(D.parameters(), lr=LR_D) opt_G = torch.optim.Adam(G.parameters(), lr=LR_G) times = 0filedatalist = []for step in range(10000):artist_paintings = artist_works() # real painting from artistG_ideas = torch.randn(BATCH_SIZE, N_IDEAS) # random ideasG_paintings = G(G_ideas) # fake painting from G (random ideas)prob_artist0 = D(artist_paintings) # D try to increase this probprob_artist1 = D(G_paintings) # D try to reduce this probD_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))G_loss = torch.mean(torch.log(1. - prob_artist1))opt_D.zero_grad()D_loss.backward(retain_graph=True) # reusing computational graphopt_D.step()opt_G.zero_grad()G_loss.backward(retain_graph=True)opt_G.step()if step % 100 == 0: # plottingplt.cla()tempdata = G_paintings[0].detach().numpy()tempdata = tempdata.reshape((8, 8))plt.imshow(tempdata, cmap=plt.cm.gray_r)# plt.draw()plt.savefig(PNGFILE + '%d.png' % times)filedatalist.append(PNGFILE + '%d.png' % times)times += 1plt.pause(0.01)generated_images = [] for png_path in filedatalist:generated_images.append(imageio.imread(png_path)) shutil.rmtree(PNGFILE) imageio.mimsave('gan.gif', generated_images, 'GIF', duration=0.1) 《新程序員》:云原生和全面數(shù)字化實踐50位技術專家共同創(chuàng)作,文字、視頻、音頻交互閱讀總結
以上是生活随笔為你收集整理的【GANs入门】pytorch-GANs任务迁移-单个目标(数字的生成)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【MPI编程】任意节点数的蝶形求和(高性
- 下一篇: 【Pytorch学习】用pytorch搭