Tensorflow2 model.compile()理解
在TensorFLow2中進行神經網絡模型的訓練主要包括以下幾個主要的步驟:
- 導入相關模塊import
- 準備數據,拆分訓練集train、測試集test
- 搭建神經網絡模型model (兩種方法:Sequential或自定義模型class)
- 模型編譯model.compile()
- 模型訓練model.fit()
- 查看模型model.summary()
- 模型評價
- 模型預測model.predict()
model.compile()的作用就是為搭建好的神經網絡模型設置損失函數loss、優化器optimizer、準確性評價函數metrics。
這些方法的作用分別是:
- 損失函數和優化器用在反向傳播的時候,我們會求損失函數對訓練變量的導數,即梯度,然后根據選擇的優化器來確定參數更新公式,根據公式對可訓練參數進行更新。
- 準確性評價函數用在評估模型預測的準確性。在模型訓練的過程中,我們會記錄模型在訓練集、驗證集上的預測準確性,之后會據此繪制準確率隨著訓練次數的變化曲線。通過查看和對比訓練集、測試集隨著訓練次數的準確率曲線,我們能發現模型是否是過擬合、欠擬合,或者也可以發現多少輪后可以停止模型訓練了。
由上可以看出,神經網絡模型建模訓練的過程中,核心的靈魂環節就是搭建模型和編譯compile了。所以,這是非常非常重要的一個模塊。
1、首先,上代碼,直觀看下model.compile()在神經網絡建模中的使用示例
#model.compile()配置模型訓練方法 model.compile( optimizer = tf.keras.optimizers.SGD(lr = 0.1), #使用SGD優化器,學習率為0.1loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = False), #配置損失函數metrics = ['sparse_categorical_accuracy'] #標注準確性評價指標 )2、解讀model.compile()中配置方法
compile(optimizer, #優化器loss=None, #損失函數metrics=None, # ["準確率”]loss_weights=None, sample_weight_mode=None, weighted_metrics=None, target_tensors=None)2.1 ?loss可以是字符串形式給出的損失函數的名字,也可以是函數形式
例如:”mse" 或者 tf.keras.losses.MeanSquaredError()
? ? ? ? ? ?"sparse_categorical_crossentropy" ?或者 ?tf.keras.losses.SparseCatagoricalCrossentropy(from_logits = False)
? ?損失函數經常需要使用softmax函數來將輸出轉化為概率分布的形式,在這里from_logits代表是否將輸出轉為概率分布的形式,為False時表示轉換為概率分布,為True時表示不轉換,直接輸出
2.2 ?optimizer可以是字符串形式給出的優化器名字,也可以是函數形式,使用函數形式可以設置學習率、動量和超參數
例如:“sgd” ? 或者 ? tf.optimizers.SGD(lr = 學習率,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?decay = 學習率衰減率,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? momentum = 動量參數)
? ? ? ? ? ?“adagrad" ?或者 ?tf.keras.optimizers.Adagrad(lr = 學習率,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?decay = 學習率衰減率)
? ? ? ? ? ? ”adadelta" ?或者 ?tf.keras.optimizers.Adadelta(lr = 學習率,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?decay = 學習率衰減率)
? ? ? ? ? ? ?“adam" ?或者 ?tf.keras.optimizers.Adam(lr = 學習率,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? decay = 學習率衰減率)
2.3 Metrics神經網絡模型的準確性評價指標
例如:
? ? ? ? "accuracy" : y_ 和 y 都是數值,如y_ = [1] y = [1] ?#y_為真實值,y為預測值
? ? ? ? “sparse_accuracy":y_和y都是以獨熱碼 和概率分布表示,如y_ = [0, 1, 0], y = [0.256, 0.695, 0.048]
? ? ? ? "sparse_categorical_accuracy" :y_是以數值形式給出,y是以 獨熱碼給出,如y_ = [1], y = [0.256 0.695, 0.048]
?
總結
以上是生活随笔為你收集整理的Tensorflow2 model.compile()理解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: MindMapper中的鱼骨图该怎样进行
- 下一篇: 漫画网站服务器,漫画网站服务器配置