生成对抗网络简介(包含TensorFlow代码示例)【翻译】
- 判別模型 vs. 生成模型
- 示例:近似一維高斯分布
- 提高樣本多樣性
- 最后的思考
- 關于GAN的一些討論
最近,大家對生成模型的興趣又開始出現(OpenAI關于生成模型的案例)。生成模型可以學習如何生成數據,這些數據和我們給定的數據很類似(真實數據)。我們用一個例子來描述這背后的原理,比如,我們希望構建一個模型,可以生成高質量的新聞,那么它必須先學習很多的新聞文章。或者說,模型的內部應當有一個很好的關于新聞文檔的表示方式。我們希望用這個表示方式來幫助我們完成相關的任務,比如通過主題給新聞進行分類。
事實上,訓練這樣的一個模型并不容易,但是最近幾年,此類研究進展很大。其中一個非常有名的模型就是生成對抗網絡(Generative Adversarial Networks, GANs)。Facebook著名的AI研究院院長和深度學習研究專家Yann LeCun,最近將GANs稱為深度學習中最為重要的發展:
“There are many interesting recent development in deep learning…The most important one, in my opinion, is adversarial training (also called GAN for Generative Adversarial Networks). This, and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion.” – Yann LeCun
這篇博客剩下的部分就來詳細描述GAN的形成,并提供一個非常簡單的示例(包含一個TensorFlow代碼),使用GAN來解決一個小問題。
判別模型 vs. 生成模型
GAN是一個非常有趣的想法,它首先由University of Montreal的Ian Goodfellow(現在在OpenAI)在2014年提出的。GAN背后的想法包含兩個競爭性的神經網絡模型。其中一個將噪音作為輸入,并產生一些樣本(生成器)。另一個模型(判別器)則同時接受生成器生成的數據和真實的數據,并分別出它們的來源。這兩個網絡坐連續的博弈,其中生成器會生成的數據應當與真實數據越來越像,而判別器則逐漸具有更好的判別能力。這兩個神經網絡同時訓練,并最終使得生成模型生成的數據與真實數據幾乎沒有差異。
在這里,經常可以看到生成器一般被類比為偽造者嘗試生產假幣,而判別器被當作是警察,嘗試甄別出假幣。這個設定和增強學習有點像,生成器會從判別器那里接受到一個獎勵信號,可以知道它生成的數據是否正確。但是增強學習和GAN最大的區別是我們可以從判別模型到生成模型反向傳播梯度信息。所以生成器知道如何調整參數以更好的生成數據并騙過判別器。
目前,GANs主要運用在自然圖像的建模中。他們可以完成非常棒的圖像生成任務。他們可以生成比其它用極大似然作為訓練目標的模型更銳利的圖像。如下圖,是GANs產生的一些圖像示例:
Generated bedrooms. Source: “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks”?https://arxiv.org/abs/1511.06434v2
Generated CIFAR-10 samples. Source: “Improved Techniques for Training GANs”?https://arxiv.org/abs/1606.03498
示例:近似一維高斯分布
為了更好地理解它的工作原理,我們使用一個GAN解決一個簡單的問題——學習近似一個一維高斯分布。這是根據Eric Jang的一篇博客的示例。全部的演示代碼見https://github.com/AYLIEN/gan-intro?。這里我們只關注我們感興趣的代碼片段。
首先,我們創建一個真實的分布,一個均值為4,標準差為0.5的高斯分布。我們可以從這個函數中得到一些該分布的樣本(按值排序)。
這個分布的如下圖所示:
同時,我們也定義一個生成器,輸入為噪音分布(樣本函數與之前類似)。根據Eric Jang的例子,我們也使用分層抽樣方法產生生成器的輸入噪音。樣本首先從一個范圍中均勻抽取,并隨機受到擾動。
我們的生成器和判別器網絡非常簡單。生成器用一個線性轉換,通過一個非線性(a softplus function)傳遞,然后接著另一個線性轉換:
在這個例子中,我們發現判別器必須要比生成器厲害,否則的話它就無法區別出樣本正確的來源。因此,我們使用一個更加深層的網絡,維度很高。除了最終層外,使用tanh非線性。最終層是sigmoid函數。
我們使用TensorFlow圖吧這些片段連接起來。同時,我們也為每層網絡定一個了一個損失函數,目標是生成器能騙過判別器。
我們使用TensorFlow中GradientDescentOptimizer來優化每層網絡。我們應當注意到找出好的優化參數需要調整好參數。
為了訓練好這個模型,我們從數據分布中抽取一部分數據以及噪音分布,并在優化生成器參數和判別器參數之間來回切換。
模型動畫的訓練演示參考(請自備梯子):https://youtu.be/mObnwR-u8pc
在這里我們看到,開始的時候生成器的生成結果與真實數據差異很大。在迭代很多次之后(大約在750次迭代之后)就會接近真實分布了。但是在收斂之前,它一直在輸入分布均值附近優化。最后的訓練結果如下圖所示:
這背后的原理也容易理解。判別器是以單獨的樣本來看真實數據和生成器生成數據的。如果生成器能夠產生真實數據均值附近的數據就能夠騙過判別器。
有很多方法都能解決這個問題。在這個例子中,我們可以增加某種早期停止排序的內容,當兩個分部之間的相似性達到一個閾值的時候就停止訓練。然而,我們很難有一個更加泛化的方法運用在更加復雜的問題上。即便在這個簡單的例子上,也很難保證在早期停止的時候生成器的分布可以達到某種程度。最好的方式是是的判別器可以同時檢測多個樣本。
提高樣本多樣性
根據Tim Salimans和他在OpenAI里的同事最近的工作,GAN的一個主要問題是生成器可能會在某個參數的環境下崩塌,并輸出一個不太好的分布。他們提出了一個解決方案:運行判別器同時查看多個樣本,稱為minibatch discrimination。
在這篇文章中,minibatch discrimination判別被定義成任何一個方法,其判別器能夠同時檢測所有的樣本,并決定哪些是從生成器生成的,哪些是真實樣本。他們也提出了一個具體的方法,可以為一個給定的樣本和其他樣本之間的距離進行批建模。這些距離然后與原始的樣本聯結起來,并傳遞給判別器,因此,它在分類的時候即使用了樣本也使用了距離。
這個方法可以簡單地總結如下:
- 取出判別器中間層的某些輸出。
- 通過一個3D張量相乘得到一個矩陣(下面代碼中的of size num_kernels x kernel_dim)。
- 計算這個矩陣中行間的L1距離,并應用在一個負指數上。
- 一個樣本的minibatch特征是這些指數距離的總和。
- 使用新創建的minibatch特征把原始輸入和這個minibatch層聯結起來,并把這個傳給判別器的下一個輸入。
在TensorFlow中,可以變成如下形式:
新的訓練過程如下(請自備梯子):https://youtu.be/0r3g7-4bMYU
顯然,加了minibatch之后生成器的分布更加寬了。收斂后如下圖所示:
最后一點,batch大小比超參數更加重要。在我們的例子中,我們設置的比較小(小于16附近)。
最后的思考
生成對抗網絡給了我們一個全新的角度來做無監督的學習。GANs的大多數成功的應用都在圖像識別領域,但是我們正在把研究拓展到自然語言處理中。其中一個重要的問題是如何評價這些模型。在圖像識別中我們可以通過看生成的圖片來確定這些模型的好壞,盡管這不是一個好的方法。在文本領域,這沒什么用處。在基于極大似然的訓練模型中,我們可以基于似然產生未觀測數據的度量,但是這并不在這里合適。從產生的樣本中,產生基于核密度估計的GAN論文在這里有一些。但是在高維數據中并不合適。另一個解決方案是基于一些接下來的任務做評價(如分類)。
關于GAN的一些討論
最后,我們提供一些關于GAN的討論:
(本文原文)An introduction to Generative Adversarial Networks (with code in TensorFlow)
Ian Goodfellow關于GAN在NLP任務中應用困難的解釋
從對抗樣本出發解釋GAN
知乎關于GAN的最新發展的討論
Ian Goodfellow在NIPS2016上作得關于GAN的匯報
七月在線關于上述匯報的翻譯
國立臺灣大學李宏毅老師關于GAN的中文課程
原文地址:http://www.datalearner.com/blog/1051494816250033
與50位技術專家面對面20年技術見證,附贈技術全景圖總結
以上是生活随笔為你收集整理的生成对抗网络简介(包含TensorFlow代码示例)【翻译】的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: AI芯片怎么降功耗?从ISSCC2017
- 下一篇: JNI实现源码分析【二 数据结构】