飞桨框架2.0RC新增模型保存、加载方案,与用户场景完美匹配,更全面、更易用
通過(guò)一段時(shí)間系統(tǒng)的課程學(xué)習(xí),算法攻城獅張同學(xué)對(duì)于飛槳框架的使用越來(lái)越順手,于是他打算在企業(yè)內(nèi)嘗試使用飛槳進(jìn)行AI產(chǎn)業(yè)落地。
但是AI產(chǎn)業(yè)落地并不是分秒鐘的事情,除了專(zhuān)業(yè)技能過(guò)硬,熟悉飛槳的使用外,在落地過(guò)程中還會(huì)遇到很多細(xì)節(jié)的問(wèn)題。這不,他就想到了兩個(gè)棘手的小問(wèn)題:
· 企業(yè)的數(shù)據(jù)集都比較大,使用這種大規(guī)模數(shù)據(jù)集進(jìn)行模型訓(xùn)練的耗時(shí)會(huì)很長(zhǎng),往往需要持續(xù)數(shù)天甚至更長(zhǎng)時(shí)間。這種情況下,就需要多次保存模型訓(xùn)練的參數(shù),避免由于訓(xùn)練意外中斷而前功盡棄。
模型訓(xùn)練至收斂后,需要將模型及參數(shù)保存下來(lái),用于后續(xù)在服務(wù)器或者移動(dòng)端環(huán)境中部署,在推理場(chǎng)景中發(fā)揮作用。
那么,如何高效地解決張同學(xué)提出的這兩個(gè)問(wèn)題呢?飛槳框架2.0RC為開(kāi)發(fā)者提供了全新的動(dòng)態(tài)圖模式下的模型保存與加載體系,其中包含兩個(gè)模型保存與加載的方案,分別適用于上述兩個(gè)場(chǎng)景。(友情提示:飛槳框架2.0RC版本開(kāi)始主推動(dòng)態(tài)圖模式,仍兼容保留對(duì)靜態(tài)圖模式的支持,但不再推薦使用。)
場(chǎng)景一:訓(xùn)練場(chǎng)景模型保存與加載(只需保存和加載模型參數(shù)即可)
在訓(xùn)練階段,開(kāi)發(fā)者僅需要保存和加載模型參數(shù)即可。飛槳提供了paddle.save和paddle.load接口用于實(shí)現(xiàn)該功能。當(dāng)保存和加載模型參數(shù)時(shí),可使用 paddle.save/load 結(jié)合Layer和Optimizer的state_dict()方法實(shí)現(xiàn),這兩個(gè)接口的關(guān)系入下圖所示:
- state_dict是保存Layer或者Optimizer參數(shù)的鍵值對(duì),state_dict的key為參數(shù)名,value為參數(shù)真實(shí)的numpy array數(shù)值;
- pdparams為L(zhǎng)ayer參數(shù)文件名的后綴;
- pdopt為Optimizer參數(shù)文件名的后綴。
相關(guān)文檔獲取地址:
https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/framework/io/save_cn.html
https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/framework/io/load_cn.html
下面舉一個(gè)簡(jiǎn)單的線(xiàn)性回歸模型示例。
如果想每訓(xùn)練一個(gè)epoch就保存一次模型參數(shù),要如何實(shí)現(xiàn)呢?此時(shí)只需要在每個(gè)epoch訓(xùn)練結(jié)束后,保存一次Layer和Optimizer的參數(shù)即可。
# train for epoch_id in range(EPOCH_NUM):for batch_id, (image, label) in enumerate(loader()):out = layer(image)loss = loss_fn(out, label)loss.backward()adam.step()adam.clear_grad()print("Epoch {} batch {}: loss = {}".format(epoch_id, batch_id, np.mean(loss.numpy())))# save state_dictpaddle.save(layer.state_dict(), "{}/epoch_{}.pdparams".format('checkpoints', epoch_id))paddle.save(adam.state_dict(),"{}/epoch_{}.pdopt".format('checkpoints', epoch_id))執(zhí)行該訓(xùn)練示例后,保存的結(jié)果如下,每個(gè)epoch執(zhí)行完都保存了相應(yīng)的訓(xùn)練參數(shù)。
λ ls checkpoints/ epoch_0.pdopt epoch_1.pdparams epoch_3.pdopt epoch_4.pdparams epoch_6.pdopt epoch_7.pdparams epoch_9.pdopt epoch_0.pdparams epoch_2.pdopt epoch_3.pdparams epoch_5.pdopt epoch_6.pdparams epoch_8.pdopt epoch_9.pdparams epoch_1.pdopt epoch_2.pdparams epoch_4.pdopt epoch_5.pdparams epoch_7.pdopt epoch_8.pdparams如果訓(xùn)練意外中斷,想要從某個(gè)epoch繼續(xù)訓(xùn)練,或者想要加載某個(gè)推理效果更好的epoch的保存結(jié)果,可以通過(guò)paddle.load接口加載,然后通過(guò)set_state_dict接口配置。這里以加載第8個(gè)epoch的訓(xùn)練參數(shù)為例,只需要在創(chuàng)建網(wǎng)絡(luò)之后,訓(xùn)練之前,將相應(yīng)文件加載配置到的Layer和Optimizer中即可。
# create network layer = LinearNet() loss_fn = nn.CrossEntropyLoss() adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())# load layer_state_dict = paddle.load("checkpoints/epoch_8.pdparams") opt_state_dict = paddle.load("checkpoints/epoch_8.pdopt")layer.set_state_dict(layer_state_dict) adam.set_state_dict(opt_state_dict)# create data loader ... # train ...場(chǎng)景二:推理&部署場(chǎng)景的模型保存與加載(需要同時(shí)保存推理模型的結(jié)構(gòu)和參數(shù))
在推理&部署場(chǎng)景中,需要同時(shí)保存推理模型的結(jié)構(gòu)和參數(shù),此時(shí)需要使用 paddle.jit.save和paddle.jit.load 接口實(shí)現(xiàn)。
paddle.jit.save接口會(huì)自動(dòng)調(diào)用飛槳框架2.0RC推出的動(dòng)態(tài)圖轉(zhuǎn)靜態(tài)圖功能,使得用戶(hù)可以做到使用動(dòng)態(tài)圖編程調(diào)試,自動(dòng)轉(zhuǎn)成靜態(tài)圖訓(xùn)練部署。
(小伙伴們應(yīng)該了解,動(dòng)態(tài)圖是即時(shí)執(zhí)行即時(shí)得到結(jié)果,并不會(huì)記錄模型的結(jié)構(gòu)信息。動(dòng)態(tài)圖在保存推理模型時(shí),需要先將動(dòng)態(tài)圖模型轉(zhuǎn)換為靜態(tài)圖寫(xiě)法,編譯得到對(duì)應(yīng)的模型結(jié)構(gòu)再保存,飛槳框架2.0RC版本推出的動(dòng)靜轉(zhuǎn)換體系,用于解決這個(gè)難題。)
這兩個(gè)接口的基本關(guān)系如下圖所示:
當(dāng)用戶(hù)使用paddle.jit.save保存Layer對(duì)象時(shí),飛槳會(huì)自動(dòng)將用戶(hù)編寫(xiě)的動(dòng)態(tài)圖Layer模型轉(zhuǎn)換為靜態(tài)圖寫(xiě)法,并編譯得到模型結(jié)構(gòu),同時(shí)將模型結(jié)構(gòu)與參數(shù)保存。paddle.jit.save需要適配飛槳沿用已久的推理模型與參數(shù)格式,做到前向完全兼容,因此其保存格式與paddle.save有所區(qū)別,具體包括三種文件:保存模型結(jié)構(gòu)的*.pdmodel文件;保存推理用參數(shù)的*.pdiparams文件和保存兼容變量信息的*.pdiparams.info文件,這幾個(gè)文件后綴均為paddle.jit.save保存時(shí)默認(rèn)使用的文件后綴。
仍然接著前面的模型示例追加說(shuō)明,直接在train實(shí)現(xiàn)后調(diào)用paddle.jit.save保存推理模型即可。
# save inference model from paddle.static import InputSpec paddle.jit.save(layer=layer,path="inference/linear",input_spec=[InputSpec(shape=[None, 784], dtype='float32')])此時(shí),inference目錄下的保存結(jié)果為:
λ ls inference/ linear.pdiparams linear.pdiparams.info linear.pdmodel這里InputSpec是用于描述推理模型輸入特性的對(duì)象,包括輸入Tensor的shape和dtype。paddle.jit.save會(huì)根據(jù)input_spec傳入的輸入描述信息,推理得到整個(gè)模型的結(jié)構(gòu),該場(chǎng)景中input_spec是必須指定的。
另外,在使用paddle.jit.save保存需要注意:確保Layer.forward方法中僅實(shí)現(xiàn)推理相關(guān)的功能,避免將訓(xùn)練所需的loss計(jì)算邏輯寫(xiě)入forward方法。Layer更準(zhǔn)確的語(yǔ)義是描述一個(gè)具有推理功能的模型對(duì)象,輸出推理的結(jié)果,而loss計(jì)算是僅屬于模型訓(xùn)練中的概念。將loss計(jì)算的實(shí)現(xiàn)放到Layer.forward方法中,會(huì)使Layer在不同場(chǎng)景下概念有所差別,并且增大Layer使用的復(fù)雜性,因此建議保持Layer實(shí)現(xiàn)的簡(jiǎn)潔性。舉個(gè)例子:
下面代碼是不推薦的Layer寫(xiě)法:
class LinearNet(nn.Layer):def __init__(self):super(LinearNet, self).__init__()self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)def forward(self, x, label=None):out = self._linear(x)if label:loss = nn.functional.cross_entropy(out, label)avg_loss = nn.functional.mean(loss)return out, avg_losselse:return out這才是推薦的Layer寫(xiě)法:
class LinearNet(nn.Layer):def __init__(self):super(LinearNet, self).__init__()self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)def forward(self, x, label=None):out = self._linear(x)if label:loss = nn.functional.cross_entropy(out, label)avg_loss = nn.functional.mean(loss)return out, avg_losselse:return out在模型加載方面,一般由paddle.jit.save保存的推理模型,會(huì)通過(guò)Paddle inference或者Paddle Lite等高效的推理工具加載并進(jìn)行線(xiàn)上部署,不會(huì)再通過(guò)飛槳基礎(chǔ)框架加載使用。但出于接口設(shè)計(jì)一致性的考慮,飛槳框架2.0RC新增了paddle.jit.load接口,也支持了通過(guò)飛槳基礎(chǔ)框架的接口加載paddle.jit.save保存的推理模型,且加載后可以用于推理,也可以用于繼續(xù)進(jìn)行增量訓(xùn)練。
- 加載后進(jìn)行推理:直接以保存之前的方式使用加載的對(duì)象即可,但要注意加載對(duì)象的輸入需要和保存時(shí)指定的input_spec保持一致,示例如下:
- 加載后進(jìn)行增量訓(xùn)練:由于保存的模型是用于inference的模型,所以需要重新為網(wǎng)絡(luò)添加loss function和optimizer,示例如下:
相關(guān)文檔獲取地址:
https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/fluid/dygraph/jit/save_cn.html
https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/fluid/dygraph/jit/load_cn.html
模型保存和加載是深度學(xué)習(xí)框架的基礎(chǔ)I/O模塊,是模型訓(xùn)練與部署的必用接口,整體關(guān)系如下圖所示:
在最新版本中,相應(yīng)的模型保存加載體系也有重大更新,在接口功能和易用性方面均有顯著提升。除上述功能外,模型保存與加載模塊還包含其他諸多易用的功能:
-
以上接口均兼容支持了從飛槳框架1.x的paddle.fluid.io.save_inference_model、 paddle.fluid.save 、 paddle.fluid.io.save_params/save_persistables 等接口保存的結(jié)果中加載模型或者參數(shù);
相關(guān)文檔獲取地址:
https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0
rc1/guides/02_paddle2.0_develop/08_model_save_load_cn.html#id15 -
仍然支持直接保存靜態(tài)圖模型,并保留了相關(guān)接口;
-
推出了動(dòng)靜一體的高層API,也有相應(yīng)的保存/加載接口。更多內(nèi)容,可以查閱飛槳官網(wǎng)→API文檔(paddle.Modell.save/load)。
模型保存和加載體系是深度學(xué)習(xí)框架必不可少的部分,飛槳團(tuán)隊(duì)仍然在優(yōu)化相關(guān)接口的質(zhì)量和易用性,希望能給文中的張同學(xué)及廣大開(kāi)發(fā)者帶來(lái)更好的產(chǎn)品體驗(yàn)。如果大家發(fā)現(xiàn)模型保存與加載相關(guān)接口有BUG、出現(xiàn)覆蓋不到的場(chǎng)景等問(wèn)題,歡迎通過(guò)Issue反饋給我們。對(duì)于工匠精神的追求,飛槳一直在努力,并非常期待與廣大開(kāi)發(fā)者攜手并行,共同構(gòu)建功能強(qiáng)大且易用的開(kāi)源深度學(xué)習(xí)框架。
如果您想詳細(xì)了解更多飛槳的相關(guān)內(nèi)容,請(qǐng)參閱以下文檔。
·飛槳官網(wǎng)地址·
https://www.paddlepaddle.org.cn/
·飛槳開(kāi)源框架項(xiàng)目地址·
GitHub: https://github.com/PaddlePaddle/Paddle
Gitee: https://gitee.com/paddlepaddle/Paddle
期待你的加入
百度開(kāi)發(fā)者中心已開(kāi)啟征稿模式,歡迎開(kāi)發(fā)者登錄developer.baidu.com進(jìn)行投稿,優(yōu)質(zhì)文章將獲得豐厚獎(jiǎng)勵(lì)和推廣資源。
總結(jié)
以上是生活随笔為你收集整理的飞桨框架2.0RC新增模型保存、加载方案,与用户场景完美匹配,更全面、更易用的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: San介绍以及在百度APP的实践
- 下一篇: 百度牵头,全球首个面向商业化运营的Rob