onnx模型部署(一) ONNXRuntime
????通常我們在訓練模型時可以使用很多不同的框架,比如有的同學喜歡用 Pytorch,有的同學喜歡使用 TensorFLow,也有的喜歡 MXNet,以及深度學習最開始流行的 Caffe等等,這樣不同的訓練框架就導致了產生不同的模型結果包,在模型進行部署推理時就需要不同的依賴庫,而且同一個框架比如tensorflow 不同的版本之間的差異較大, 為了解決這個混亂問題, LF AI 這個組織聯合 Facebook, MicroSoft等公司制定了機器學習模型的標準,這個標準叫做ONNX, Open Neural Network Exchage,所有其他框架產生的模型包 (.pth, .pb) 都可以轉換成這個標準格式,轉換成這個標準格式后,就可以使用統一的 ONNX Runtime等工具進行統一部署。
????這其實可以和 JVM 對比,
A Java virtual machine (JVM) is a virtual machine that enables a computer to run Java programs as well as programs written in other languages that are also compiled to Java bytecode. The JVM is detailed by a specification that formally describes what is required in a JVM implementation. Having a specification ensures interoperability of Java programs across different implementations so that program authors using the Java Development Kit (JDK) need not worry about idiosyncrasies of the underlying hardware platform.
JAVA中有 JAVA 語言 + .jar 包 + JVM,同時還有其他的語言比如 Scala等也是建立在 JVM上運行的,因此不同的語言只要都最后將程序轉換成 JVM可以統一識別的格式,就可以在統一的跨平臺 JVM JAVA 虛擬機上運行。這里JVM使用的 包是二進制包,因此里面的內容是不可知的,人類難以直觀理解的。
這里 ONNX 標準采取了谷歌開發 protocal buffers 作為格式標準,這個格式是在 XML, json的基礎上發展的,是一個人類易理解的格式。ONNX 官網對ONNX的介紹如下:
ONNX defines a common set of operators - the building blocks of machine learning and deep learning models - and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers.
ONNX支持的模型來源,基本上囊括了我們日常使用的所有框架:
ONNX的文件格式,采用的是谷歌的 protocal buffers,和 caffe采用的一致。
ONNX定義的數據類包括了我們常用的數據類型,用來定義模型中的輸出輸出格式
ONNX中定義了很多我們常用的節點,比如 Conv,ReLU,BN, maxpool等等約124種,同時也在不停地更新中,當遇到自帶節點庫中沒有的節點時,我們也可以自己寫一個節點
有了輸入輸出,以及計算節點,就可以根據 pytorch框架中的 forward 記錄一張模型從輸入圖片到輸出的計算圖,ONNX 就是將這張計算圖用標準的格式存儲下來了,可以通過一個工具 Netron對 ONNX 進行可視化,如第一張圖右側所示;
保存成統一的 ONNX 格式后,就可以使用統一的運行平臺來進行 inference。
pytorch原生支持 ONNX 格式轉碼,下面是實例:
1. 將pytorch模型轉換為onnx格式,直接傻瓜式調用 torch.onnx.export(model, input, output_name)
import torch from torchvision import modelsnet = models.resnet.resnet18(pretrained=True) dummpy_input = torch.randn(1,3,224,224) torch.onnx.export(net, dummpy_input, 'resnet18.onnx')2. 對生成的 onnx 進行查看
import onnx# Load the ONNX model model = onnx.load("resnet18.onnx")# Check that the IR is well formed onnx.checker.check_model(model)# Print a human readable representation of the graph print(onnx.helper.printable_graph(model.graph))支持ONNX的runtime就是類似于JVM將統一的ONNX格式的模型包運行起來,包括對ONNX 模型進行解讀,優化(融合conv-bn等操作),運行。
推理
完整代碼
import torch from torchvision import modelsnet = models.resnet.resnet18(pretrained=True) dummpy_input = torch.randn(1,3,224,224) torch.onnx.export(net, dummpy_input, 'resnet18.onnx')import onnx# Load the ONNX model model = onnx.load("resnet18.onnx")# Check that the IR is well formed onnx.checker.check_model(model)# Print a human readable representation of the graph print(onnx.helper.printable_graph(model.graph))import onnxruntime as rt import numpy as np data = np.array(np.random.randn(1,3,224,224)) sess = rt.InferenceSession('resnet18.onnx') input_name = sess.get_inputs()[0].name label_name = sess.get_outputs()[0].namepred_onx = sess.run([label_name], {input_name:data.astype(np.float32)})[0] print(pred_onx) print(np.argmax(pred_onx))完整代碼
總結
以上是生活随笔為你收集整理的onnx模型部署(一) ONNXRuntime的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 神经网络模型量化
- 下一篇: 包含目录、库目录、附加包含目录、附加库目