飞桨框架2.0RC新增模型保存、加载方案,与用户场景完美匹配,更全面、更易用
通過一段時間系統(tǒng)的課程學(xué)習(xí),算法攻城獅張同學(xué)對于飛槳框架的使用越來越順手,于是他打算在企業(yè)內(nèi)嘗試使用飛槳進(jìn)行AI產(chǎn)業(yè)落地。
但是AI產(chǎn)業(yè)落地并不是分秒鐘的事情,除了專業(yè)技能過硬,熟悉飛槳的使用外,在落地過程中還會遇到很多細(xì)節(jié)的問題。這不,他就想到了兩個棘手的小問題:
· 企業(yè)的數(shù)據(jù)集都比較大,使用這種大規(guī)模數(shù)據(jù)集進(jìn)行模型訓(xùn)練的耗時會很長,往往需要持續(xù)數(shù)天甚至更長時間。這種情況下,就需要多次保存模型訓(xùn)練的參數(shù),避免由于訓(xùn)練意外中斷而前功盡棄。
模型訓(xùn)練至收斂后,需要將模型及參數(shù)保存下來,用于后續(xù)在服務(wù)器或者移動端環(huán)境中部署,在推理場景中發(fā)揮作用。
那么,如何高效地解決張同學(xué)提出的這兩個問題呢?飛槳框架2.0RC為開發(fā)者提供了全新的動態(tài)圖模式下的模型保存與加載體系,其中包含兩個模型保存與加載的方案,分別適用于上述兩個場景。(友情提示:飛槳框架2.0RC版本開始主推動態(tài)圖模式,仍兼容保留對靜態(tài)圖模式的支持,但不再推薦使用。)
場景一:訓(xùn)練場景模型保存與加載(只需保存和加載模型參數(shù)即可)
在訓(xùn)練階段,開發(fā)者僅需要保存和加載模型參數(shù)即可。飛槳提供了paddle.save和paddle.load接口用于實現(xiàn)該功能。當(dāng)保存和加載模型參數(shù)時,可使用 paddle.save/load 結(jié)合Layer和Optimizer的state_dict()方法實現(xiàn),這兩個接口的關(guān)系入下圖所示:
- state_dict是保存Layer或者Optimizer參數(shù)的鍵值對,state_dict的key為參數(shù)名,value為參數(shù)真實的numpy array數(shù)值;
- pdparams為Layer參數(shù)文件名的后綴;
- pdopt為Optimizer參數(shù)文件名的后綴。
相關(guān)文檔獲取地址:
https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/framework/io/save_cn.html
https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/framework/io/load_cn.html
下面舉一個簡單的線性回歸模型示例。
如果想每訓(xùn)練一個epoch就保存一次模型參數(shù),要如何實現(xiàn)呢?此時只需要在每個epoch訓(xùn)練結(jié)束后,保存一次Layer和Optimizer的參數(shù)即可。
# train for epoch_id in range(EPOCH_NUM):for batch_id, (image, label) in enumerate(loader()):out = layer(image)loss = loss_fn(out, label)loss.backward()adam.step()adam.clear_grad()print("Epoch {} batch {}: loss = {}".format(epoch_id, batch_id, np.mean(loss.numpy())))# save state_dictpaddle.save(layer.state_dict(), "{}/epoch_{}.pdparams".format('checkpoints', epoch_id))paddle.save(adam.state_dict(),"{}/epoch_{}.pdopt".format('checkpoints', epoch_id))執(zhí)行該訓(xùn)練示例后,保存的結(jié)果如下,每個epoch執(zhí)行完都保存了相應(yīng)的訓(xùn)練參數(shù)。
λ ls checkpoints/ epoch_0.pdopt epoch_1.pdparams epoch_3.pdopt epoch_4.pdparams epoch_6.pdopt epoch_7.pdparams epoch_9.pdopt epoch_0.pdparams epoch_2.pdopt epoch_3.pdparams epoch_5.pdopt epoch_6.pdparams epoch_8.pdopt epoch_9.pdparams epoch_1.pdopt epoch_2.pdparams epoch_4.pdopt epoch_5.pdparams epoch_7.pdopt epoch_8.pdparams如果訓(xùn)練意外中斷,想要從某個epoch繼續(xù)訓(xùn)練,或者想要加載某個推理效果更好的epoch的保存結(jié)果,可以通過paddle.load接口加載,然后通過set_state_dict接口配置。這里以加載第8個epoch的訓(xùn)練參數(shù)為例,只需要在創(chuàng)建網(wǎng)絡(luò)之后,訓(xùn)練之前,將相應(yīng)文件加載配置到的Layer和Optimizer中即可。
# create network layer = LinearNet() loss_fn = nn.CrossEntropyLoss() adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())# load layer_state_dict = paddle.load("checkpoints/epoch_8.pdparams") opt_state_dict = paddle.load("checkpoints/epoch_8.pdopt")layer.set_state_dict(layer_state_dict) adam.set_state_dict(opt_state_dict)# create data loader ... # train ...場景二:推理&部署場景的模型保存與加載(需要同時保存推理模型的結(jié)構(gòu)和參數(shù))
在推理&部署場景中,需要同時保存推理模型的結(jié)構(gòu)和參數(shù),此時需要使用 paddle.jit.save和paddle.jit.load 接口實現(xiàn)。
paddle.jit.save接口會自動調(diào)用飛槳框架2.0RC推出的動態(tài)圖轉(zhuǎn)靜態(tài)圖功能,使得用戶可以做到使用動態(tài)圖編程調(diào)試,自動轉(zhuǎn)成靜態(tài)圖訓(xùn)練部署。
(小伙伴們應(yīng)該了解,動態(tài)圖是即時執(zhí)行即時得到結(jié)果,并不會記錄模型的結(jié)構(gòu)信息。動態(tài)圖在保存推理模型時,需要先將動態(tài)圖模型轉(zhuǎn)換為靜態(tài)圖寫法,編譯得到對應(yīng)的模型結(jié)構(gòu)再保存,飛槳框架2.0RC版本推出的動靜轉(zhuǎn)換體系,用于解決這個難題。)
這兩個接口的基本關(guān)系如下圖所示:
當(dāng)用戶使用paddle.jit.save保存Layer對象時,飛槳會自動將用戶編寫的動態(tài)圖Layer模型轉(zhuǎn)換為靜態(tài)圖寫法,并編譯得到模型結(jié)構(gòu),同時將模型結(jié)構(gòu)與參數(shù)保存。paddle.jit.save需要適配飛槳沿用已久的推理模型與參數(shù)格式,做到前向完全兼容,因此其保存格式與paddle.save有所區(qū)別,具體包括三種文件:保存模型結(jié)構(gòu)的*.pdmodel文件;保存推理用參數(shù)的*.pdiparams文件和保存兼容變量信息的*.pdiparams.info文件,這幾個文件后綴均為paddle.jit.save保存時默認(rèn)使用的文件后綴。
仍然接著前面的模型示例追加說明,直接在train實現(xiàn)后調(diào)用paddle.jit.save保存推理模型即可。
# save inference model from paddle.static import InputSpec paddle.jit.save(layer=layer,path="inference/linear",input_spec=[InputSpec(shape=[None, 784], dtype='float32')])此時,inference目錄下的保存結(jié)果為:
λ ls inference/ linear.pdiparams linear.pdiparams.info linear.pdmodel這里InputSpec是用于描述推理模型輸入特性的對象,包括輸入Tensor的shape和dtype。paddle.jit.save會根據(jù)input_spec傳入的輸入描述信息,推理得到整個模型的結(jié)構(gòu),該場景中input_spec是必須指定的。
另外,在使用paddle.jit.save保存需要注意:確保Layer.forward方法中僅實現(xiàn)推理相關(guān)的功能,避免將訓(xùn)練所需的loss計算邏輯寫入forward方法。Layer更準(zhǔn)確的語義是描述一個具有推理功能的模型對象,輸出推理的結(jié)果,而loss計算是僅屬于模型訓(xùn)練中的概念。將loss計算的實現(xiàn)放到Layer.forward方法中,會使Layer在不同場景下概念有所差別,并且增大Layer使用的復(fù)雜性,因此建議保持Layer實現(xiàn)的簡潔性。舉個例子:
下面代碼是不推薦的Layer寫法:
class LinearNet(nn.Layer):def __init__(self):super(LinearNet, self).__init__()self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)def forward(self, x, label=None):out = self._linear(x)if label:loss = nn.functional.cross_entropy(out, label)avg_loss = nn.functional.mean(loss)return out, avg_losselse:return out這才是推薦的Layer寫法:
class LinearNet(nn.Layer):def __init__(self):super(LinearNet, self).__init__()self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)def forward(self, x, label=None):out = self._linear(x)if label:loss = nn.functional.cross_entropy(out, label)avg_loss = nn.functional.mean(loss)return out, avg_losselse:return out在模型加載方面,一般由paddle.jit.save保存的推理模型,會通過Paddle inference或者Paddle Lite等高效的推理工具加載并進(jìn)行線上部署,不會再通過飛槳基礎(chǔ)框架加載使用。但出于接口設(shè)計一致性的考慮,飛槳框架2.0RC新增了paddle.jit.load接口,也支持了通過飛槳基礎(chǔ)框架的接口加載paddle.jit.save保存的推理模型,且加載后可以用于推理,也可以用于繼續(xù)進(jìn)行增量訓(xùn)練。
- 加載后進(jìn)行推理:直接以保存之前的方式使用加載的對象即可,但要注意加載對象的輸入需要和保存時指定的input_spec保持一致,示例如下:
- 加載后進(jìn)行增量訓(xùn)練:由于保存的模型是用于inference的模型,所以需要重新為網(wǎng)絡(luò)添加loss function和optimizer,示例如下:
相關(guān)文檔獲取地址:
https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/fluid/dygraph/jit/save_cn.html
https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/fluid/dygraph/jit/load_cn.html
模型保存和加載是深度學(xué)習(xí)框架的基礎(chǔ)I/O模塊,是模型訓(xùn)練與部署的必用接口,整體關(guān)系如下圖所示:
在最新版本中,相應(yīng)的模型保存加載體系也有重大更新,在接口功能和易用性方面均有顯著提升。除上述功能外,模型保存與加載模塊還包含其他諸多易用的功能:
-
以上接口均兼容支持了從飛槳框架1.x的paddle.fluid.io.save_inference_model、 paddle.fluid.save 、 paddle.fluid.io.save_params/save_persistables 等接口保存的結(jié)果中加載模型或者參數(shù);
相關(guān)文檔獲取地址:
https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0
rc1/guides/02_paddle2.0_develop/08_model_save_load_cn.html#id15 -
仍然支持直接保存靜態(tài)圖模型,并保留了相關(guān)接口;
-
推出了動靜一體的高層API,也有相應(yīng)的保存/加載接口。更多內(nèi)容,可以查閱飛槳官網(wǎng)→API文檔(paddle.Modell.save/load)。
模型保存和加載體系是深度學(xué)習(xí)框架必不可少的部分,飛槳團(tuán)隊仍然在優(yōu)化相關(guān)接口的質(zhì)量和易用性,希望能給文中的張同學(xué)及廣大開發(fā)者帶來更好的產(chǎn)品體驗。如果大家發(fā)現(xiàn)模型保存與加載相關(guān)接口有BUG、出現(xiàn)覆蓋不到的場景等問題,歡迎通過Issue反饋給我們。對于工匠精神的追求,飛槳一直在努力,并非常期待與廣大開發(fā)者攜手并行,共同構(gòu)建功能強大且易用的開源深度學(xué)習(xí)框架。
如果您想詳細(xì)了解更多飛槳的相關(guān)內(nèi)容,請參閱以下文檔。
·飛槳官網(wǎng)地址·
https://www.paddlepaddle.org.cn/
·飛槳開源框架項目地址·
GitHub: https://github.com/PaddlePaddle/Paddle
Gitee: https://gitee.com/paddlepaddle/Paddle
期待你的加入
百度開發(fā)者中心已開啟征稿模式,歡迎開發(fā)者登錄developer.baidu.com進(jìn)行投稿,優(yōu)質(zhì)文章將獲得豐厚獎勵和推廣資源。
總結(jié)
以上是生活随笔為你收集整理的飞桨框架2.0RC新增模型保存、加载方案,与用户场景完美匹配,更全面、更易用的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: San介绍以及在百度APP的实践
- 下一篇: 百度牵头,全球首个面向商业化运营的Rob