GAN对抗生成网络原始论文理解笔记
文章目錄
- 論文:Generative Adversarial Nets
- 符號(hào)意義
- 生成器(Generator)
- 判別器(Discriminator)
- 生成器和判別器的關(guān)系
- GAN的訓(xùn)練流程簡(jiǎn)述
- 論文中的生成模型和判別模型
- GAN的數(shù)學(xué)理論
- 最大似然估計(jì)轉(zhuǎn)換為最小化KL散度問題
- 定義PGP_GPG?
- 全局最優(yōu)
論文:Generative Adversarial Nets
符號(hào)意義
- G()表示對(duì)生成器功能的一個(gè)封裝函數(shù)
- D()表示對(duì)判別器功能的一個(gè)封裝函數(shù)
- x表示真實(shí)數(shù)據(jù)
- z表示含噪音的數(shù)據(jù)
- x ̄\overline xx表示G(z),將噪音數(shù)據(jù)輸入到生成器得到的結(jié)果
- θgθ_gθg?表示生成器的參數(shù)
- θdθ_dθd?表示判別器的參數(shù)
- Pdata(x)P_{data}(x)Pdata?(x)表示真實(shí)數(shù)據(jù)分布
- PG(x;θ)P_G(x;\theta)PG?(x;θ)表示生成器生成的數(shù)據(jù)分布
生成器(Generator)
狹義的生成器就是輸入一個(gè)向量,通過生成器,輸出一個(gè)高維向量(代表圖片、文字等)
其中輸入向量的每一個(gè)維度都代表一個(gè)特征。如下圖示例:
判別器(Discriminator)
狹義的判別器就是輸入數(shù)據(jù)(生成器產(chǎn)物或者真實(shí)數(shù)據(jù)),通過判別器,輸出一個(gè)標(biāo)量數(shù)值,輸出的數(shù)值越大,則代表這個(gè)數(shù)據(jù)越真實(shí)。如下圖示例(假設(shè)輸出數(shù)值在0-1之間):
生成器和判別器的關(guān)系
結(jié)合圖來理解生成器和判別器的關(guān)系:
- 首先輸入噪音讓生成器v1生成圖片
- 之后輸入不同來源的圖片到判別器v1,由判別器v1來判斷圖片是真實(shí)圖片還是生成器生成的圖片
- 然后為了騙過判別器,生成器v1升級(jí)為v2,再生成新的圖片。
- 再將不同來源的圖片輸入到升級(jí)的判別器v2來判斷圖片是真實(shí)圖片還是生成器生成的圖片
- 依次循環(huán)下去,直到判別器無法區(qū)分圖片來源,也就是生成器產(chǎn)生的圖片真實(shí)度越來越接近真實(shí)圖片的真實(shí)度。
GAN的訓(xùn)練流程簡(jiǎn)述
- 在每個(gè)訓(xùn)練迭代器中:
- 先訓(xùn)練判別器
- 從數(shù)萬張圖片(數(shù)據(jù)集)中采樣出m個(gè)樣本,即{ x1,x2,...,xm{x_1,x_2,...,x_m}x1?,x2?,...,xm?}
- 隨機(jī)從一個(gè)分布(高斯分布或均勻分布)里采樣出有噪音的m個(gè)樣本,即{z1,z2,...,zm{z_1,z_2,...,z_m}z1?,z2?,...,zm?}
- 通過生成器獲得生成的數(shù)據(jù),即{ x ̄1,x ̄2,...,x ̄m{\overline x_1,\overline x_2,...,\overline x_m}x1?,x2?,...,xm?},其中x ̄i\overline{x}_ixi?=G(ziz_izi?)
- 更新判別器的參數(shù)使v ̄\overline vv最大。
- v ̄=1m[logD(xi)+log(1?D(x ̄i))]\overline v=\frac{1}{m}[logD(x_i)+log(1-D(\overline x_i))]v=m1?[logD(xi?)+log(1?D(xi?))]
- 梯度下降更新參數(shù)θdθ_dθd?
- 再訓(xùn)練生成器
- 隨機(jī)從一個(gè)分布(高斯分布或均勻分布)里采樣出有噪音的m個(gè)樣本,即{z1,z2,...,zm{z_1,z_2,...,z_m}z1?,z2?,...,zm?}
- 更新生成器的參數(shù)使v ̄\overline vv最小
- v ̄=1m[logD(xi)+log(1?D(x ̄i))]\overline v=\frac{1}{m}[logD(x_i)+log(1-D(\overline x_i))]v=m1?[logD(xi?)+log(1?D(xi?))]
- 梯度下降更新參數(shù)θgθ_gθg?
- 先訓(xùn)練判別器
當(dāng)訓(xùn)練判別器的時(shí)候,就相當(dāng)于把生成器固定住了,當(dāng)訓(xùn)練生成器的時(shí)候,就相當(dāng)于把判別器固定住了,于是就有對(duì)上述關(guān)于v ̄\overline vv的講解:
對(duì)于判別器,目標(biāo)是提升辨認(rèn)圖片來源的能力,對(duì)真實(shí)圖片輸出大的數(shù)值,所以D(xi)D(x_i)D(xi?)越大越好,D(x ̄i)D(\overline x_i)D(xi?)越小越好,也就是v ̄\overline vv越大越好。
對(duì)于生成器:目的是希望自己生成的圖片越來越真實(shí),也就是要讓D(x ̄i)D(\overline x_i)D(xi?)越大越好,也就是v ̄\overline vv越小越好(另一項(xiàng)當(dāng)成常數(shù)即可)。
論文中的生成模型和判別模型
GAN提出了兩個(gè)模型:
-
生成模型(Generator)
生成模型主要是用來生成數(shù)據(jù)分布,目的是盡量與原數(shù)據(jù)分布接近。 -
判別模型(Discriminator)
判別模型主要是用來判斷樣本是來自真實(shí)分布還是生成模型生成的分布。目的是能夠更加好地區(qū)分哪些樣本來自真實(shí)數(shù)據(jù),哪些樣本來自生成模型的數(shù)據(jù),越真實(shí)的數(shù)據(jù)得到的結(jié)果越大。
用數(shù)學(xué)來表示訓(xùn)練過程中兩模型的變化,如下圖:
綠色線表示真實(shí)數(shù)據(jù)的分布,藍(lán)色線表示生成模型輸出的數(shù)據(jù)分布,紅色線表示判別器(越高就表示給的分?jǐn)?shù)越大)
GAN的數(shù)學(xué)理論
最大似然估計(jì)轉(zhuǎn)換為最小化KL散度問題
真實(shí)數(shù)據(jù)的分布是Pdata(x)P_{data}(x)Pdata?(x) ,我們定義一個(gè)分布PG(xi;θ)P_G(x_i;\theta)PG?(xi?;θ) ,我們想要找到一組參數(shù)θ\thetaθ,使得PG(xi;θ)P_G(x_i;\theta)PG?(xi?;θ)越接近Pdata(x)P_{data}(x)Pdata?(x)越好。比如說,PG(xi;θ)P_G(x_i;\theta)PG?(xi?;θ) 如果是一個(gè)高斯混合模型,那么θ\thetaθ就是均值和方差。
采用極大似然估計(jì)方法,我們從真實(shí)數(shù)據(jù)分布 Pdata(x)P_{data}(x)Pdata?(x)里面取樣 m 個(gè)點(diǎn),x1,x2,...,xm{x_1,x_2,...,x_m}x1?,x2?,...,xm?,根據(jù)給定的參數(shù) θ 我們可以算出某個(gè)x在該分布的概率 PG(xi;θ)P_G(x_i;θ)PG?(xi?;θ),即:
也可以將極大似然估計(jì)等價(jià)于最小化KL散度,我們需要找一個(gè)最大的θ\thetaθ使得PG(xi;θ)P_G(x_i;\theta)PG?(xi?;θ)接近Pdata(x)P_{data}(x)Pdata?(x),就有下列式子:
將其化簡(jiǎn),得:
由于需要最大化概率的θ\thetaθ,也就是可以近似等價(jià)于原分布的期望,可得:
然后再展開成期望定義的形式,并且加減一項(xiàng)常數(shù)項(xiàng)(不含θ\thetaθ),不影響結(jié)果,有:
最終化成了最小化KL散度的形式。
其中KL散度用來衡量?jī)煞N概率分布的相似程度,越小則表示兩種概率分布越接近。形式為:
所以機(jī)器學(xué)習(xí)中的最大似然估計(jì),其實(shí)就是最小化我們要尋找的目標(biāo)分布PGP_GPG?與PdataP_{data}Pdata?的KL散度。
定義PGP_GPG?
如何來定義PGP_GPG?呢?
以前是采用高斯分布來定義的,但是生成的圖片會(huì)很模糊,采用更復(fù)雜的分布的話,最大似然會(huì)沒法計(jì)算。所以就引進(jìn)了Generator來定義PGP_GPG?,如下圖:
全局最優(yōu)
優(yōu)化目標(biāo)是最小化PGP_GPG?與PdataP_{data}Pdata?之間的差異:
雖然我們不知道PGP_GPG?與PdataP_{data}Pdata?的公式,但是我們可以從這兩個(gè)分布中采樣出一些樣本。
對(duì)PGP_GPG?,我們從給定的數(shù)據(jù)集中采樣出一些樣本。(該步驟對(duì)應(yīng)訓(xùn)練判別器流程步驟1)
對(duì)PdataP_{data}Pdata?,我們隨機(jī)采樣出一些向量,經(jīng)過Generator輸出一些圖片。(該步驟對(duì)應(yīng)訓(xùn)練判別器流程步驟2,3)
之后經(jīng)過Discriminator我們就可以計(jì)算PGP_GPG?與PdataP_{data}Pdata?的收斂。Discriminator的目標(biāo)函數(shù)是:
該目標(biāo)函數(shù)對(duì)應(yīng)訓(xùn)練判別器的損失函數(shù),意思是假設(shè)x是從PdataP_{data}Pdata? 里面采樣出來的,那么希望D(x)越大越好。如果是從 PGP_GPG?里面采樣出來的,就希望它的值越小越好。x~PdataP_{data}Pdata?表示該均值的x都來自PdataP_{data}Pdata?分布。
我們的目標(biāo)是讓判別器無法區(qū)分PGP_GPG?與PdataP_{data}Pdata?,也就是讓它沒辦法把V(G,D)調(diào)大。接下來從數(shù)學(xué)上去解釋這個(gè)結(jié)論。
給定生成器,我們要找到能最大化目標(biāo)函數(shù)V(D,G)的D*:
現(xiàn)在我們把積分里面的這一項(xiàng)拿出來看:
PdatalogD(x)+PG(x)log(1?D(x))P_{data}logD(x)+P_G(x)log(1-D(x))Pdata?logD(x)+PG?(x)log(1?D(x))
我們想要找到一組參數(shù)D*,使這一項(xiàng)最大。把式子簡(jiǎn)寫一下,將PdataP_{data}Pdata?用a表示,PGP_GPG?用b表示,得:
f(D)=alog(D)+blog(1?D)f(D)=alog(D)+blog(1-D)f(D)=alog(D)+blog(1?D)
對(duì)D求導(dǎo)得:
df(D)dD=a?1D+b?11?D?(?1)\frac{df(D)}{dD}=a*\frac1D+b*\frac1{1-D}*(-1)dDdf(D)?=a?D1?+b?1?D1??(?1)
另這個(gè)求導(dǎo)結(jié)果為0,得:
D?=aa+bD^*=\frac a{a+b}D?=a+ba?
將a,b代回去,得:
D?=Pdata(x)Pdata(x)+PG(x)D^*=\frac {P_{data}(x)}{P_{data}(x)+P_G(x)}D?=Pdata?(x)+PG?(x)Pdata?(x)?
再將這個(gè)D帶入V(G,D*)中,然后分子分母同時(shí)除以2,之后可以化簡(jiǎn)為JS散度形式(KL散度的變體,解決了KL散度非對(duì)稱的問題),得:
當(dāng)PdataP_{data}Pdata?=PGP_GPG?時(shí),JS散度為0,值為-2log2,達(dá)到最優(yōu)(也就是讓判別器沒辦法把V(G,D)調(diào)大)。這是從正向證明當(dāng)PdataP_{data}Pdata?=PGP_GPG?時(shí)達(dá)到最優(yōu),還需從反向證明才可以得出當(dāng)且僅當(dāng)PdataP_{data}Pdata?=PGP_GPG?才可以達(dá)到的全局最優(yōu)。
反向證明很容易:假設(shè)PdataP_{data}Pdata?=PGP_GPG?,那么D*=12\frac 1221?,再直接代入V(G,D*)即可得到-2log2。
所以,當(dāng)且僅當(dāng)PdataP_{data}Pdata?=PGP_GPG?才可以達(dá)到的全局最優(yōu)。也就是,當(dāng)且僅當(dāng)生成分布等于真實(shí)數(shù)據(jù)分布時(shí),我們?nèi)〉米顑?yōu)生成器。
我們從頭整理一下,我們的目標(biāo)是找到一個(gè)G*,去最小化PdataP_{data}Pdata?,PGP_GPG?的差異,也就是:
G?=argminGDiv(PG,Pdata)G^*=argmin_GDiv(P_G,P_{data})G?=argminG?Div(PG?,Pdata?)
但是這個(gè)差異沒法之間去算,所以就用一個(gè)判別器來計(jì)算這兩個(gè)分布的差異:
D?=argmaxDV(D,G)D^*=argmax_DV(D,G)D?=argmaxD?V(D,G)
所以優(yōu)化目標(biāo)就變?yōu)?#xff1a;
G?=argminGmaxDV(G,D)G^*=argmin_Gmax_DV(G,D)G?=argminG?maxD?V(G,D)
這個(gè)看起來很復(fù)雜,其實(shí)直觀理解一下,如下圖,我們假設(shè)已經(jīng)把生成器固定住了,圖片的曲線表示,紅點(diǎn)表示固定住G后的 maxD(G,D)max_D(G,D)maxD?(G,D) , 也就是 PGP_GPG? 和 PdataP_{data}Pdata? 的差異。而我們的目標(biāo)是最小化這個(gè)差異,所以下圖的三個(gè)網(wǎng)絡(luò)中, G3G_3G3? 是最優(yōu)秀的。
參考的文章:
GAN論文閱讀——原始GAN(基本概念及理論推導(dǎo))
生成對(duì)抗網(wǎng)絡(luò)(GAN) 背后的數(shù)學(xué)理論
總結(jié)
以上是生活随笔為你收集整理的GAN对抗生成网络原始论文理解笔记的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: MSDE使用说明文档
- 下一篇: SDL2笔记