TensorFlow2-网络训练技巧
生活随笔
收集整理的這篇文章主要介紹了
TensorFlow2-网络训练技巧
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
TensorFlow2網(wǎng)絡(luò)訓(xùn)練技巧
文章目錄
- TensorFlow2網(wǎng)絡(luò)訓(xùn)練技巧
- 簡介
- 過擬合與欠擬合
- 過擬合問題
- 動量(Momentum)SGD
- 學(xué)習率衰減(learning rate decay)
- 補充說明
簡介
- 在神經(jīng)網(wǎng)絡(luò)這種端到端模型的訓(xùn)練過程中,主要的關(guān)注點實際上并不多,參數(shù)初始化、激活函數(shù)、損失函數(shù)、優(yōu)化方法等,但是過深的模型不免會帶來欠擬合和過擬合的問題,為了解決過擬合帶來的問題,采用了諸如數(shù)據(jù)增強、參數(shù)正則化、Dropout、Batch Normalization、劃分驗證集(交叉驗證)等方法,這些被稱為訓(xùn)練技巧(trick)。
- 當然,為了應(yīng)對訓(xùn)練速度慢的問題,有時候也采用一些特殊的訓(xùn)練技巧于優(yōu)化器上,如加入動量的SGD等。
- 上述的訓(xùn)練trick在TensorFlow2中都提供了簡介高效的API接口,使用時直接調(diào)用這些接口即可很方便的控制訓(xùn)練、可視化訓(xùn)練、數(shù)據(jù)增廣等。
過擬合與欠擬合
- 模型的表達能力(Model Capacity)是多變的。一元線性回歸模型的表達能力很弱,它只能擬合線性分布的數(shù)據(jù);神經(jīng)網(wǎng)絡(luò)的表達能力很強,參數(shù)量龐大,可以擬合非常復(fù)雜的分布。
- 深度學(xué)習中的模型都是層次非常深的,參數(shù)極其復(fù)雜。要訓(xùn)練這樣的網(wǎng)絡(luò)是比較困難的,需要大量的數(shù)據(jù)用于參數(shù)的學(xué)習調(diào)整。數(shù)據(jù)量過少,網(wǎng)絡(luò)難以被充分訓(xùn)練,無法達到擬合訓(xùn)練集分布的效果,這種問題是機器學(xué)習中常見的欠擬合問題(underfitting),該情況下模型的表達能力不夠,模型復(fù)雜度(estimated)小于數(shù)據(jù)真實復(fù)雜度(ground-truth)。還有另一種情況,模型復(fù)雜度(estimated)大于數(shù)據(jù)真實復(fù)雜度(ground-truth),這是因為訓(xùn)練后期模型為了降低loss過分擬合訓(xùn)練數(shù)據(jù),從而導(dǎo)致擬合程度過高,模型失去泛化能力。這種問題在機器學(xué)習中稱為過擬合問題(overfitting)。過擬合現(xiàn)象在訓(xùn)練可視化過程中的表現(xiàn)為隨著訓(xùn)練輪次增加,訓(xùn)練集損失不斷減少,驗證集損失先減少后增加。
- 現(xiàn)代機器學(xué)習中,神經(jīng)網(wǎng)絡(luò)這樣的模型深度很深,模型的表達能力很強,常出現(xiàn)的問題是過擬合問題,欠擬合問題已經(jīng)較少出現(xiàn)。
過擬合問題
- 檢測
- 劃分數(shù)據(jù)集(Splitting)
- 將有標注的訓(xùn)練數(shù)據(jù)拿出小部分劃分為驗證集,驗證集由于包含標簽數(shù)據(jù),可以利用訓(xùn)練好的模型進行預(yù)測得到相關(guān)metrics(如accuracy等),用于檢測模型的訓(xùn)練情況(包括是否過擬合)。
- 可以對tensor進行直接劃分(該方法不會隨機打亂數(shù)據(jù)集)。訓(xùn)練時可以直接將驗證集作為參數(shù)傳入,在驗證集上評測的指標與training設(shè)定的metric相同。驗證集用于訓(xùn)練過程控制訓(xùn)練,模型最終應(yīng)用在測試集上。
- k折交叉驗證(K-fold cross validation)
- 之前的劃分是一次性劃分,有較大概率無法利用所有數(shù)據(jù),因為只能使用劃分的訓(xùn)練集進行訓(xùn)練,而不能使用驗證集進行訓(xùn)練,這部分驗證集數(shù)據(jù)信息就被放棄了。k折交叉驗證是為了充分利用數(shù)據(jù)性能,多次進行數(shù)據(jù)集劃分,要求每次劃分的那部分驗證集(如20%)在后面的4折中會作為訓(xùn)練集,這樣5折下來,每一部分數(shù)據(jù)都被作為訓(xùn)練數(shù)據(jù)過(5折交叉驗證是最常見的,更少不能充分驗證,更多訓(xùn)練量過大)。
- 在TensorFlow2中可以自行實現(xiàn)k折劃分。當然,TensorFlow2也提供比較簡單的接口,只需要指出劃分比例,則會自動劃分出驗證集。
- 劃分數(shù)據(jù)集(Splitting)
- 處理
- 充分的數(shù)據(jù)
- 充分的數(shù)據(jù)可以有效訓(xùn)練網(wǎng)絡(luò),要求網(wǎng)絡(luò)進行更多學(xué)習,減輕過擬合。正是ImageNet這樣的大規(guī)模標注數(shù)據(jù)集,深度學(xué)習的發(fā)展才會如此迅速。
- 降低模型復(fù)雜度
- 正則化(Regularization)方法通過在loss函數(shù)中添加懲罰項,迫使參數(shù)的范數(shù)趨近于0,從而使得低復(fù)雜的參數(shù)較大,使得復(fù)雜網(wǎng)絡(luò)退化。在有些地方,該方法也叫作weight decay(參數(shù)衰減)。
- keras模塊下的layer中參數(shù)正則化非常簡單,只要傳入正則化方法對象即可。實際使用中,TensorFlow2提供更靈活的方法。
- Dropout方法
- 一個簡單粗暴的防止過擬合的方法,以一定概率關(guān)閉神經(jīng)元連接,迫使網(wǎng)絡(luò)學(xué)習更多。
- 在TensorFlow2中,keras的layers模塊將Dropout操作封裝為一層的操作,通過堆疊使用。但是注意,添加了Dropout的網(wǎng)絡(luò)前向傳播是必須制定training參數(shù),因為測試集上預(yù)測是不應(yīng)該斷開連接,這只是訓(xùn)練時的技巧。network = Sequential([layers.Dense(256, activation='relu'),layers.Dropout(0.5), # 0.5 rate to droplayers.Dense(128, activation='relu'),layers.Dropout(0.5), # 0.5 rate to droplayers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
- 數(shù)據(jù)增廣
- Early Stopping
- 早停是防止過擬合的一種常用手段,當訓(xùn)練時驗證集metric已經(jīng)飽和或者開始變壞達到指定次數(shù)時,停止訓(xùn)練。
- 通過keras的callbacks模塊可以很方便實現(xiàn)這個功能。es = keras.callbacks.EarlyStopping(monitor='val_acc', patience=5)
- 充分的數(shù)據(jù)
動量(Momentum)SGD
- 梯度更新的方向不僅僅依賴于當前梯度的方向,而且依賴于上一次梯度的方向(可以理解為慣性)。
- 通過添加動量項,可以使得梯度下降算法找到更好的局部最優(yōu)解或者全局最優(yōu)解。但是,有時候動量SGD有可能花費更多的時間找到不是很好的解。
- 在TensorFlow2中,動量項的梯度更新不需要人為完成,只需要指定動量超參數(shù)權(quán)值,其余交由優(yōu)化器完成即可。很多優(yōu)化算法如Adam是默認使用momentum策略的,不需要人為指定。其中,指定動量項權(quán)值為0.9是一個常用策略。
學(xué)習率衰減(learning rate decay)
- 訓(xùn)練后期,過大的學(xué)習率可能導(dǎo)致不斷波動,難以優(yōu)化。此時采用學(xué)習率衰減策略會是一個不錯的方法,該策略后期會自動調(diào)整學(xué)習率。
- 同樣的,keras的callbacks模塊提供了回調(diào)函數(shù)用于減少學(xué)習率。這里的衰減是觸發(fā)執(zhí)行的,即后期monitor監(jiān)控的值不再變好的次數(shù)達到patience則會降低學(xué)習率。rl = keras.callbacks.ReduceLROnPlateau(monitor='val_acc', patience=5)
- 也可以在訓(xùn)練過程中,手動確定衰減策略降低學(xué)習率。optimizer.learning_rate = 0.2 * (100-epoch)/100
補充說明
- 本文主要針對TensorFlow2中訓(xùn)練技巧進行了簡單使用上的介紹。
- 博客同步至我的個人博客網(wǎng)站,歡迎瀏覽其他文章。
- 如有錯誤,歡迎指正。
總結(jié)
以上是生活随笔為你收集整理的TensorFlow2-网络训练技巧的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 机器学习-机器学习简介
- 下一篇: TensorFlow2-迁移学习