【机器学习】小孩都看得懂的 GAN
全文共?6327?字,55?幅圖,
預(yù)計(jì)閱讀時(shí)間?32?分鐘。
本文是「小孩都看得懂」系列的第十八篇,本系列的特點(diǎn)是內(nèi)容不長(zhǎng),碎片時(shí)間完全可以看完,但我背后付出的心血卻不少。喜歡就好!
小孩都看得懂的神經(jīng)網(wǎng)絡(luò)
小孩都看得懂的推薦系統(tǒng)
小孩都看得懂的逐步提升
小孩都看得懂的聚類(lèi)
小孩都看得懂的主成分分析
小孩都看得懂的循環(huán)神經(jīng)網(wǎng)絡(luò)
小孩都看得懂的 Embedding
小孩都看得懂的熵、交叉熵和 KL 散度
小孩都看得懂的?p-value
小孩都看得懂的假設(shè)檢驗(yàn)
小孩都看得懂的基尼不純度
小孩都看得懂的 ROC
小孩都看得懂的 SVD
小孩都看得懂的 SVD 2
小孩都看得懂的 GMM
小孩都看得懂的貝塔分布
小孩都看得懂的多臂老虎機(jī)
小孩都看得懂的 GAN
0
GAN 是什么
GAN 的全稱是 Generative Adversarial Network,中文是生成對(duì)抗網(wǎng)絡(luò)。
一言以蔽之,GAN 包含了兩個(gè)神經(jīng)網(wǎng)絡(luò),生成器(generator)和辨別器(discriminator),兩者互相博弈不斷變強(qiáng),即生成器產(chǎn)出的東西越來(lái)越逼真,辨別器的識(shí)別能力越來(lái)越牛逼。
2
造假和鑒定
生成器和辨別器之間的關(guān)系很像造假者(counterfeiter)和鑒定者(Appraiser)之間的關(guān)系。
造假者不斷造出假貨,目的就是蒙騙鑒定者,在此過(guò)程中其造假能力越來(lái)越高。
鑒定者不斷檢驗(yàn)假貨,目的就是識(shí)破造假者,在此過(guò)程中其鑒定能力越來(lái)越高。
GAN 是造假者的,也是鑒定者的,但歸根結(jié)底還是造假者的。GAN 的最終目標(biāo)是訓(xùn)練出一個(gè)“完美”的造假者,即能讓生成讓鑒定者都蒙圈的產(chǎn)品。
一動(dòng)圖勝千言,下圖展示“造假者如何一步步生成逼真的蒙娜麗莎畫(huà)而最終欺騙了鑒定者”的過(guò)程。
在此過(guò)程中,每當(dāng)造假者生成一幅圖。鑒定者會(huì)給出反饋,造假者從中學(xué)到如何改進(jìn)來(lái)畫(huà)出一張逼真圖。
3
造假鑒定網(wǎng)絡(luò)?
回到神經(jīng)網(wǎng)絡(luò),造假者用生成器來(lái)建模,鑒定者用辨別器來(lái)建模。
根據(jù)上面動(dòng)圖可知,辨別器的任務(wù)是區(qū)分哪些圖片是真實(shí)的,哪些圖片是生成器產(chǎn)生的。
接下來(lái)我們用 Python 創(chuàng)建一個(gè)極簡(jiǎn) GAN。
首先設(shè)置一個(gè)故事背景。
4
故事背景
在傾斜島(slanted island)上,每個(gè)人都是傾斜的,大概像左傾斜 45 度左右。
島主想做人臉生成器,由于島上的人的臉部特征非常簡(jiǎn)單,因此用 2 * 2 像素的模糊人臉圖片。
限于技術(shù),島主只用了個(gè)一層的神經(jīng)網(wǎng)絡(luò)。
但在這個(gè)極度簡(jiǎn)單的設(shè)置下,一層的 GAN 也能生成“傾斜人臉”。
5
辨別人臉
下圖展示四個(gè)人臉的樣子。
從 2*2 像素來(lái)表示人臉,深色代表此處有人臉,淺色代表此處沒(méi)有人臉。
如果不是人臉呢?那么其 2*2 像素圖中的元素就是隨機(jī)的,如下所示。
復(fù)習(xí)一下:
人臉:對(duì)角線上是深色,非對(duì)角線上是淺色
非人臉:任意四處都可能是深色或淺色
像素可以用 0 到 1 的數(shù)值來(lái)表示:
人臉:對(duì)角線上的數(shù)值大,非對(duì)角線上的數(shù)值小
非人臉:任意四處都可能是 0-1 之間的任意數(shù)值
弄清了人臉照片和非人臉照片用不同特征的 2*2 數(shù)值矩陣表示之后,接下來(lái)兩節(jié)我們來(lái)看如何構(gòu)建辨別器(discriminator)和生成器(generator)。
先分析辨別器。
6
辨別器
辨別器就是用來(lái)辨別人臉,那么當(dāng)看到照片的像素值時(shí),如何辨別呢?
簡(jiǎn)單!上節(jié)已經(jīng)分析過(guò):
人臉:對(duì)角線上的數(shù)值大,非對(duì)角線上的數(shù)值小
非人臉:任意四處都可能是 0-1 之間的任意數(shù)值
如果要用一個(gè)數(shù)值表示人臉和非人臉,該用什么樣的操作呢?也簡(jiǎn)單,如下圖所示,加上 (1,1) 位置的元素,減去 (1,2) 位置的元素,減去?(2,1) 位置的元素,加上?(2,2) 位置的元素,得到一個(gè)數(shù)值就可以了。
人臉得到的分?jǐn)?shù)是 2(較大),非人臉得到的分?jǐn)?shù)是 -0.5(較小)。
設(shè)定一個(gè)閾值 1,得分大于 1 是人臉,小于 1 不是人臉。
將上述內(nèi)容用神經(jīng)網(wǎng)絡(luò)來(lái)表示,就成了下面的極簡(jiǎn)辨別器了。注意除了“加減減加”矩陣 4 個(gè)元素之外,最后還加上一個(gè)偏置項(xiàng)(bias)得到最終得分。
辨別器最終要判斷是否是人臉,因此產(chǎn)出是一個(gè)概率,需要用 sigmoid 函數(shù)將得分 1 轉(zhuǎn)化成概率 0.73。給定概率閾值 0.5,由于 0.73 > 0.5,辨別器判斷該圖是人臉。
對(duì)另一張非人臉的圖,用同樣操作,最后算出得分 -0.5,用 sigmoid 函數(shù)轉(zhuǎn)換。給定概率閾值 0.5,由于 0.37 <?0.5,辨別器判斷該圖是人臉。
7
生成器
辨別器目標(biāo)是判斷人臉。而生成器目標(biāo)是生成人臉,那什么樣的矩陣像素是人臉圖呢?簡(jiǎn)單!該規(guī)則被已經(jīng)分析多次了:
人臉:對(duì)角線上的數(shù)值大,非對(duì)角線上的數(shù)值小
非人臉:任意四處都可能是 0-1 之間的任意數(shù)值
現(xiàn)在來(lái)看生成過(guò)程。第一步就是從 0-1 之間隨機(jī)選取一個(gè)數(shù),比如 0.7。
回憶生成器的目的是生成人臉,即要保證最終 2*2 矩陣的對(duì)角線上的像素要大(用粗線表明),而非對(duì)角線上的像素要小(用細(xì)線表明)。
舉例,生成矩陣 (1,1) 位置的值,w = 1, b = 1,計(jì)算的分 wz + b = 1.7。
同理計(jì)算矩陣其他三個(gè)位置的得分。
最后都用 sigmoid 函數(shù)將得分轉(zhuǎn)換一下,確保像素值在 0-1 之間。
注意按上圖這樣給權(quán)重 [1, -1, -1, 1] 和偏置 1,有因?yàn)?z 總是在 0 和 1 之間的一個(gè)正數(shù),這樣的一個(gè)神經(jīng)網(wǎng)絡(luò)(生成器)總可以生成一個(gè)像人臉的 2*2 的像素矩陣。
根據(jù)本節(jié)和上節(jié)的展示,我們已經(jīng)知道什么樣的辨別器可以判斷人臉,什么樣的生成器可以生成好的人臉,即什么樣的 GAN 是個(gè)好 GAN。這些都是由權(quán)重和偏置決定的,接下來(lái)看看它們是怎么訓(xùn)練出來(lái)的。首先復(fù)習(xí)一下誤差函數(shù)(error function)。
8
誤差函數(shù)
通常把正類(lèi)用 1 表示,負(fù)類(lèi)用 0 表示。在本例中人臉是正類(lèi),用 1 表示;非人臉是負(fù)類(lèi),用 0 表示。
當(dāng)標(biāo)簽為 1 時(shí)(人臉),-ln(x)?是一個(gè)好的誤差函數(shù),因?yàn)?/p>
當(dāng)預(yù)測(cè)不準(zhǔn)時(shí)(預(yù)測(cè)非人臉,假設(shè) 0.1),那么誤差應(yīng)該較大,- ln(0.1) 較大。?
當(dāng)預(yù)測(cè)準(zhǔn)時(shí)(預(yù)測(cè)人臉,假設(shè) 0.9),那么誤差應(yīng)該較小,-ln(0.9) 較小。
當(dāng)標(biāo)簽為 0?時(shí)(非人臉),-ln(1-x)?是一個(gè)好的誤差函數(shù)。
當(dāng)預(yù)測(cè)準(zhǔn)時(shí)(預(yù)測(cè)非人臉,假設(shè) 0.1),那么誤差應(yīng)該較小,- ln(1-0.1) 較大。?
當(dāng)預(yù)測(cè)不準(zhǔn)時(shí)(預(yù)測(cè)人臉,假設(shè) 0.9),那么誤差應(yīng)該較大,-ln(1-0.9) 較小。
根據(jù)下面兩張總結(jié)圖再鞏固一下 ln 函數(shù)作為誤差函數(shù)的邏輯。
接下來(lái)就是?GAN 中博弈,即生成器和辨別器放在一起會(huì)發(fā)生什么事情。
9
生成器和辨別器放在一起
復(fù)習(xí)一下兩者的結(jié)構(gòu)。
生成器:輸入是一個(gè) 0-1 之間的隨機(jī)數(shù),輸出是圖片的像素矩陣
辨別器:輸入是圖片像素矩陣,輸出是一個(gè)概率值
下面動(dòng)圖展示了從生成器到辨別器的流程。
因?yàn)樵搱D片是從生成器來(lái)的,不是真實(shí)圖片,因此一個(gè)好的辨別器會(huì)判斷這不是臉,那么使用標(biāo)簽為 0 對(duì)應(yīng)的誤差函數(shù),-ln(1-prediction)。
反過(guò)來(lái),一個(gè)好的生成器想騙過(guò)辨別器,即想讓辨別器判斷這是臉,那么使用標(biāo)簽為 1?對(duì)應(yīng)的誤差函數(shù),-ln(prediction)。
好戲來(lái)了,用 G 表示生成器,D 表示辨別器,那么
G(z) 是生成器的產(chǎn)出,即像素矩陣,它也是辨別器的輸入
D(G(z)) 是辨別器的產(chǎn)出,即概率,又是上面誤差函數(shù)里的 prediction
為了使生成器和辨別器都變強(qiáng),我們希望最小化誤差函數(shù)
????-ln(D(G(z)) - ln(1-D(G(z))
其中 D(G(z)) 就是辨別器的 prediction。
將我們得到的誤差函數(shù)對(duì)比 GAN 論文中的目標(biāo)函數(shù)(下圖),發(fā)現(xiàn)還是有些差別:
解釋如下:
辨別器除了接收生成器產(chǎn)出的圖片 G(z),還會(huì)接收真實(shí)圖片 x,在這時(shí)一個(gè)好的辨別器會(huì)判斷這是臉,那么使用標(biāo)簽為 1?對(duì)應(yīng)的誤差函數(shù),-ln(-prediction)。那么對(duì)于辨別器,需要最小化的誤差函數(shù)是
????-ln(D(x))?-?ln(1-D(G(z))
將負(fù)號(hào)去掉,等價(jià)于最大化
????ln(D(x))?+?ln(1-D(G(z))
這個(gè)不就是 V(D,G) 么?此過(guò)程是固定生成器,來(lái)優(yōu)化辨別器來(lái)識(shí)別假圖片。
V(D, G) 最大化后,在固定辨別器,來(lái)優(yōu)化生成器來(lái)生成以假亂真的圖片。但是生成器的誤差函數(shù)不是 -ln(D(G(z))?嗎?怎么能和 V(D, G) 扯上關(guān)系呢?其實(shí)?-ln(D(G(z)) 等價(jià)于?ln(1-D(G(z)),這時(shí) V(D, G) 的第二項(xiàng),而其第一項(xiàng) ln(D(x))?對(duì)于 G 是個(gè)常數(shù),加不加都無(wú)所謂。
最后 V(D, G) 中的兩項(xiàng)都有期望符號(hào),在實(shí)際優(yōu)化中我們就通過(guò) n 個(gè)樣本的統(tǒng)計(jì)平均值來(lái)實(shí)現(xiàn)。第一項(xiàng)期望中的 x 從真實(shí)數(shù)據(jù)分布 p_data(x) 中來(lái),第一項(xiàng)期望中的 z 從特定概率分布 p_z(z) 中來(lái)。
綜上,先通過(guò) D 最大化 V(D,G) 再通過(guò) G 最小化 V(D, G)。
10
訓(xùn)練 GAN
在訓(xùn)練中,當(dāng)人臉來(lái)自生成器,通過(guò)最小化誤差函數(shù),辨別器輸出概率值接近 0。
當(dāng)人臉來(lái)自真實(shí)圖片,通過(guò)最小化誤差函數(shù),辨別器輸出概率值接近 1。
當(dāng)然所有神經(jīng)網(wǎng)絡(luò)的訓(xùn)練算法都是梯度下降了。
OK,接下來(lái)的內(nèi)容確實(shí)不適合普通小孩了,對(duì)數(shù)學(xué)和編程有強(qiáng)烈興趣的小孩可以繼續(xù)看下去 。
11
數(shù)學(xué)推導(dǎo)
辨別器:從像素矩陣到概率
生成器:從隨機(jī)數(shù) z 到像素矩陣
得到誤差函數(shù)相對(duì)于生成器和辨別器中的權(quán)重和偏置的各種偏導(dǎo)數(shù)后,就可以寫(xiě)代碼實(shí)現(xiàn)了。
12
Python 實(shí)現(xiàn) - 準(zhǔn)備工作
引入 numpy 和 matplotlib。
import numpy as np from numpy import random from matplotlib import pyplot as plt %matplotlib inline編寫(xiě)繪畫(huà)人臉像素的函數(shù)。
def view_samples(samples, m, n):fig, axes = plt.subplots(figsize=(10, 10), nrows=m, ncols=n, sharey=True, sharex=True)for ax, img in zip(axes.flatten(), samples):ax.xaxis.set_visible(False)ax.yaxis.set_visible(False)im = ax.imshow(1-img.reshape((2,2)), cmap='Greys_r') return fig, axes畫(huà)出四張人臉,注意其像素矩陣中對(duì)角線上的數(shù)值大,非對(duì)角線上的數(shù)值小。
faces = [np.array([1,0,0,1]),np.array([0.9,0.1,0.2,0.8]),np.array([0.9,0.2,0.1,0.8]),np.array([0.8,0.1,0.2,0.9]),np.array([0.8,0.2,0.1,0.9])]_ = view_samples(faces, 1, 4)畫(huà)出二十張非人臉,注意其像素矩陣中的數(shù)都是隨機(jī)的。
noise = [np.random.randn(2,2) for i in range(20)] def generate_random_image():return [np.random.random(), np.random.random(), np.random.random(), np.random.random()]_ = view_samples(noise, 4,5)13
Python 實(shí)現(xiàn) - 構(gòu)建辨別器
首先實(shí)現(xiàn) sigmoid 函數(shù)。
def sigmoid(x):return np.exp(x)/(1.0+np.exp(x))用面向?qū)ο缶幊?#xff08;OOP)來(lái)編寫(xiě)辨別器,代碼如下:
其中
__init__() 是構(gòu)建函數(shù)
forward() 函數(shù)將像素矩陣打平成向量 x,乘上權(quán)重 w 加上偏置 b 得到得分,再通過(guò) sigmoid() 函數(shù)轉(zhuǎn)成概率
error_form_image() 計(jì)算當(dāng)接收真實(shí)圖片為輸入的誤差函數(shù)
error_form_noise()?計(jì)算當(dāng)接收生成器為輸入的誤差函數(shù)
derivatives_form_image()?計(jì)算當(dāng)接收真實(shí)圖片為輸入誤差函數(shù)對(duì)權(quán)重 w 和偏置 b 的偏導(dǎo)數(shù)?
derivatives_form_noise()?計(jì)算當(dāng)接收生成器為輸入誤差函數(shù)對(duì)權(quán)重 w 和偏置 b 的偏導(dǎo)數(shù)?
update_form_image()?計(jì)算當(dāng)接收真實(shí)圖片為輸入時(shí)的梯度下降法
update_form_noise()?計(jì)算當(dāng)接收生成器為輸入時(shí)的梯度下降法
14
Python 實(shí)現(xiàn) - 構(gòu)建生成器
用面向?qū)ο缶幊?#xff08;OOP)來(lái)編寫(xiě)生成器,代碼如下:
其中
__init__() 是構(gòu)建函數(shù)
forward() 函數(shù)將隨機(jī)數(shù)?z 乘上權(quán)重 w 加上偏置 b 得到得分,再通過(guò) sigmoid() 函數(shù)轉(zhuǎn)成像素
error()?計(jì)算當(dāng)固定辨別器為輸入的誤差函數(shù),分兩步:
生成器的 forward() 函數(shù)得到像素
辨別器的 forward() 函數(shù)得到概率
derivatives()?計(jì)算當(dāng)固定辨別器為輸入誤差函數(shù)對(duì)權(quán)重 w 和偏置 b 的偏導(dǎo)數(shù),對(duì)著上一節(jié)數(shù)學(xué)公式看代碼?
update()?計(jì)算當(dāng)固定辨別器為輸入時(shí)的梯度下降法
15
Python 實(shí)現(xiàn) -?訓(xùn)練 GAN
設(shè)定 1000 期(epoch),即將數(shù)據(jù)遍歷 1000 遍開(kāi)始訓(xùn)練,記錄每期生成器和辨別器的誤差。
畫(huà)出生成器和辨別器的誤差函數(shù)圖,發(fā)現(xiàn)生成器逐步趨于穩(wěn)定。
plt.plot(errors_generator) plt.title("Generator error function") plt.legend("gen") plt.show() plt.plot(errors_discriminator) plt.legend('disc') plt.title("Discriminator error function")16
Python 實(shí)現(xiàn) -?結(jié)果展示
生成圖片。
generated_images = [] for i in range(4):z = random.random()generated_image = G.forward(z)generated_images.append(generated_image) _ = view_samples(generated_images, 1, 4) for i in generated_images:print(i)[0.94688171?0.03401213?0.04080795?0.96308679] [0.95653992?0.03437852?0.03579494?0.97063836] [0.95056667?0.03414339?0.03893305?0.96599501] [0.94228203?0.03386046?0.04309146?0.95941292]打印出最終 GAN 的參數(shù),即生成器和辨別器的權(quán)重和偏置。
print("Generator weights", G.weights) print("Generator biases", G.biases) print("Discriminator weights", D.weights) print("Discriminator bias", D.bias)Generator?weights?[ 0.70702123 0.03720449 -0.45703394 0.79375751] Generator?biases?[ 2.48490157 -3.36725912 -2.90139211 2.8172726 ] Discriminator?weights?[ 0.60175083 -0.29127513 -0.40093314 0.37759987] Discriminator?bias?-0.8955103005797729帶有權(quán)重和偏置的 GAN 如下所示。
圖中粗線對(duì)應(yīng)大權(quán)重,細(xì)線對(duì)應(yīng)小或者負(fù)權(quán)重。對(duì)照前面生成器要生成逼真人臉的目標(biāo)來(lái)看(即 2*2 矩陣的對(duì)角線上的值大),是不是這個(gè)權(quán)重很合理。
朋友們,你們弄懂了 GAN 了嗎?
往期精彩回顧適合初學(xué)者入門(mén)人工智能的路線及資料下載機(jī)器學(xué)習(xí)及深度學(xué)習(xí)筆記等資料打印機(jī)器學(xué)習(xí)在線手冊(cè)深度學(xué)習(xí)筆記專(zhuān)輯《統(tǒng)計(jì)學(xué)習(xí)方法》的代碼復(fù)現(xiàn)專(zhuān)輯 AI基礎(chǔ)下載黃海廣老師《機(jī)器學(xué)習(xí)課程》視頻課黃海廣老師《機(jī)器學(xué)習(xí)課程》711頁(yè)完整版課件
本站qq群955171419,加入微信群請(qǐng)掃碼:
與50位技術(shù)專(zhuān)家面對(duì)面20年技術(shù)見(jiàn)證,附贈(zèng)技術(shù)全景圖總結(jié)
以上是生活随笔為你收集整理的【机器学习】小孩都看得懂的 GAN的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 【机器学习】因子分解机(FM) 原理及在
- 下一篇: 实现多个下拉框同一批option,选中其