TensorFlow反向传播算法实现
TensorFlow反向傳播算法實(shí)現(xiàn)
反向傳播(BPN)算法是神經(jīng)網(wǎng)絡(luò)中研究最多、使用最多的算法之一,用于將輸出層中的誤差傳播到隱藏層的神經(jīng)元,然后用于更新權(quán)重。
學(xué)習(xí) BPN 算法可以分成以下兩個(gè)過程:
- 正向傳播:輸入被饋送到網(wǎng)絡(luò),信號(hào)從輸入層通過隱藏層傳播到輸出層。在輸出層,計(jì)算誤差和損失函數(shù)。
- 反向傳播:在反向傳播中,首先計(jì)算輸出層神經(jīng)元損失函數(shù)的梯度,然后計(jì)算隱藏層神經(jīng)元損失函數(shù)的梯度。接下來用梯度更新權(quán)重。
這兩個(gè)過程重復(fù)迭代直到收斂。
前期準(zhǔn)備
首先給網(wǎng)絡(luò)提供 M 個(gè)訓(xùn)練對(duì)(X,Y),X 為輸入,Y 為期望的輸出。輸入通過激活函數(shù) g(h) 和隱藏層傳播到輸出層。輸出 Yhat 是網(wǎng)絡(luò)的輸出,得到 error=Y-Yhat。其損失函數(shù) J(W) 如下:
其中,i 取遍所有輸出層的神經(jīng)元(1 到 N)。然后可以使用 J(W) 的梯度并使用鏈?zhǔn)椒▌t求導(dǎo),來計(jì)算連接第 i 個(gè)輸出層神經(jīng)元到第 j 個(gè)隱藏層神經(jīng)元的權(quán)重 Wij 的變化:
這里,Oj 是隱藏層神經(jīng)元的輸出,h 表示隱藏層的輸入值。這很容易理解,但現(xiàn)在怎么更新連接第 n 個(gè)隱藏層的神經(jīng)元 k 到第 n+1 個(gè)隱藏層的神經(jīng)元 j 的權(quán)值 Wjk?過程是相同的:將使用損失函數(shù)的梯度和鏈?zhǔn)椒▌t求導(dǎo),但這次計(jì)算 Wjk:
現(xiàn)在已經(jīng)有方程了,看看如何在 TensorFlow 中做到這一點(diǎn)。在這里,還是使用 MNIST 數(shù)據(jù)集(http://yann.lecun.com/exdb/MNIST/)。
具體實(shí)現(xiàn)過程
現(xiàn)在開始使用反向傳播算法:
-
導(dǎo)入模塊:
-
加載數(shù)據(jù)集,通過設(shè)置 one_hot=True 來使用獨(dú)熱編碼標(biāo)簽:
-
定義超參數(shù)和其他常量。這里,每個(gè)手寫數(shù)字的尺寸是 28×28=784 像素。數(shù)據(jù)集被分為 10 類,以 0 到 9 之間的數(shù)字表示。這兩點(diǎn)是固定的。學(xué)習(xí)率、最大迭代周期數(shù)、每次批量訓(xùn)練的批量大小以及隱藏層中的神經(jīng)元數(shù)量都是超參數(shù)??梢酝ㄟ^調(diào)整這些超參數(shù),看看是如何影響網(wǎng)絡(luò)表現(xiàn)的:
-
需要 Sigmoid 函數(shù)的導(dǎo)數(shù)來進(jìn)行權(quán)重更新,所以定義:
-
為訓(xùn)練數(shù)據(jù)創(chuàng)建占位符:
-
創(chuàng)建模型:
-
定義權(quán)重和偏置變量:
-
為正向傳播、誤差、梯度和更新計(jì)算創(chuàng)建計(jì)算圖:
-
定義計(jì)算精度 accuracy 的操作:
-
初始化變量:
-
執(zhí)行圖:
-
結(jié)果如下:
解讀分析
在這里,訓(xùn)練網(wǎng)絡(luò)時(shí)的批量大小為 10,如果增加批量的值,網(wǎng)絡(luò)性能就會(huì)下降。另外,需要在測(cè)試數(shù)據(jù)集上檢測(cè)訓(xùn)練好的網(wǎng)絡(luò)的精度,這里測(cè)試數(shù)據(jù)集的大小是 1000。
單隱藏層多層感知機(jī)在訓(xùn)練數(shù)據(jù)集上的準(zhǔn)確率為 84.45,在測(cè)試數(shù)據(jù)集上的準(zhǔn)確率為 92.1。這是好的,但不夠好。MNIST 數(shù)據(jù)集被用作機(jī)器學(xué)習(xí)中分類問題的基準(zhǔn)。接下來,看一下如何使用 TensorFlow 的內(nèi)置優(yōu)化器影響網(wǎng)絡(luò)性能。
總結(jié)
以上是生活随笔為你收集整理的TensorFlow反向传播算法实现的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: TensorFlow单层感知机实现
- 下一篇: TensorFlow实现多层感知机MIN