ONNX初探
ONNX初探
轉載自:https://blog.csdn.net/just_sort/article/details/112912272
0x0. 背景
最近看了一些ONNX的資料,一個最大的感受就是這些資料太凌亂了。大多數都是在介紹ONNX模型轉換中碰到的坑點以及解決辦法。很少有文章可以系統的介紹ONNX的背景,分析ONNX格式,ONNX簡化方法等。所以,綜合了相當多資料之后我準備寫一篇ONNX相關的文章,希望對大家有用。
0x1. 什么是ONNX?
簡單描述一下官方介紹,開放神經網絡交換(Open Neural Network Exchange)簡稱ONNX是微軟和Facebook提出用來表示深度學習模型的開放格式。所謂開放就是ONNX定義了一組和環境,平臺均無關的標準格式,來增強各種AI模型的可交互性。
換句話說,無論你使用何種訓練框架訓練模型(比如TensorFlow/Pytorch/OneFlow/Paddle),在訓練完畢后你都可以將這些框架的模型統一轉換ONNX這種統一的格式進行存儲。注意ONNX文件不僅僅存儲了神經網絡模型的權重,同時也存儲了模型的結構信息以及網絡中每一層的輸入輸出和一些其它的輔助信息。我們直接從onnx的官方模型倉庫拉一個yolov3-tiny的onnx模型(地址為:https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/tiny-yolov3/model)用Netron可視化一下看看ONNX模型長什么樣子。
這里我們可以看到ONNX的版本信息,這個ONNX模型是由Keras導出來的,以及模型的輸入輸出等信息,如果你對模型的輸入輸出有疑問可以直接看:https://github.com/onnx/models/blob/master/vision/object_detection_segmentation/tiny-yolov3/README.md。
在獲得ONNX模型之后,模型部署人員自然就可以將這個模型部署到兼容ONNX的運行環境中去。這里一般還會設計到額外的模型轉換工作,典型的比如在Android段利用NCNN部署模型,那么就需要將ONNX利用NCNN的轉換工具轉換到NCNN所支持的bin和param格式。
但在實際使用ONNX的過程中,大多數人對ONNX了解得并不多,僅僅認為它只是一個模型轉換工具人而已,可以利用它完成模型轉換和部署。正是因為對ONNX的不了解,在模型轉換過程中出現的各種不兼容或者不支持讓很多人浪費了大量時間。這篇文章將從理論和實踐2個方面談一談ONNX。
0x2. ProtoBuf簡介
在分析ONNX組織格式前我們需要了解ProtoBuf, 如果你非常了解ProtoBuf可以略過此節。
ONNX作為一個文件格式,我們自然需要一定的規則去讀取我們想要的信息或者是寫入我們需要保存信息。ONNX使用的是ProtoBuf這個序列化數據結構去存儲神經網絡的權重信息。熟悉Caffe或者Caffe2的同學應該知道,它們的模型存儲數據結構協議也是Protobuf。這個從安裝ONNX包的時候也可以看到:
Protobuf是一種輕便高效的結構化數據存儲格式,可以用于結構化數據串行化,或者說序列化。它很適合做數據存儲或數據交換格式。可用于通訊協議、數據存儲等領域的語言無關、平臺無關、可擴展的序列化結構數據格式。目前提供了 C++、Java、Python 三種語言的 API(摘自官方介紹)。
Protobuf協議是一個以***.proto后綴文件為基礎的,這個文件描述了用戶自定義的數據結構。如果需要了解更多細節請參考后面的資料3,這里只是想表達ONNX是基于Protobuf來做數據存儲和傳輸,那么自然onnx.proto**就是ONNX格式文件了,接下來我們就分析一下ONNX格式。
0x3. ONNX格式分析
這一節我們來分析一下ONNX的組織格式,上面提到ONNX中最核心的部分就是onnx.proto(https://github.com/onnx/onnx/blob/master/onnx/onnx.proto)這個文件了,它定義了ONNX這個數據協議的規則和一些其它信息。現在是2020年1月,這個文件有700多行,我們沒有必要把這個文件里面的每一行都貼出來,我們只要搞清楚里面的核心部分即可。在這個文件里面以message關鍵字開頭的對象是我們需要關心的。我們列一下最核心的幾個對象并解釋一下它們之間的關系。
- ModelProto
- GraphProto
- NodeProto
- ValueInfoProto
- TensorProto
- AttributeProto
當我們加載了一個ONNX之后,我們獲得的就是一個ModelProto,它包含了一些版本信息,生產者信息和一個GraphProto。在GraphProto里面又包含了四個repeated數組,它們分別是node(NodeProto類型),input(ValueInfoProto類型),output(ValueInfoProto類型)和initializer(TensorProto類型),其中node中存放了模型中所有的計算節點,input存放了模型的輸入節點,output存放了模型中所有的輸出節點,initializer存放了模型的所有權重參數。
我們知道要完整的表達一個神經網絡,不僅僅要知道網絡的各個節點信息,還要知道它們的拓撲關系。這個拓撲關系在ONNX中是如何表示的呢?ONNX的每個計算節點都會有input和output兩個數組,這兩個數組是string類型,通過input和output的指向關系,我們就可以利用上述信息快速構建出一個深度學習模型的拓撲圖。這里要注意一下,GraphProto中的input數組不僅包含我們一般理解中的圖片輸入的那個節點,還包含了模型中所有的權重。例如,Conv層里面的W權重實體是保存在initializer中的,那么相應的會有一個同名的輸入在input中,其背后的邏輯應該是把權重也看成模型的輸入,并通過initializer中的權重實體來對這個輸入做初始化,即一個賦值的過程。
最后,每個計算節點中還包含了一個AttributeProto數組,用來描述該節點的屬性,比如Conv節點或者說卷積層的屬性包含group,pad,strides等等,每一個計算節點的屬性,輸入輸出信息都詳細記錄在https://github.com/onnx/onnx/blob/master/docs/Operators.md。
0x4. onnx.helper
現在我們知道ONNX是把一個網絡的每一層或者說一個算子當成節點node,使用這些Node去構建一個Graph,即一個網絡。最后將Graph和其它的生產者信息,版本信息等合并在一起生成一個model,也即是最終的ONNX模型文件。
在構建ONNX模型的時候,https://github.com/onnx/onnx/blob/master/onnx/helper.py這個文件非常重要,我們可以利用它提供的make_node,make_graph,make_tensor等等接口完成一個ONNX模型的構建,一個示例如下:
這個官方示例為我們演示了如何使用onnx.helper的make_tensor,make_tensor_value_info,make_attribute,make_node,make_graph,make_node等方法來完整構建了一個ONNX模型。需要注意的是在上面的例子中,輸入數據是一個一維Tensor,初始維度為[2],這也是為什么經過維度為[1,4]的Pad操作之后獲得的輸出Tensor維度為[3,4]。另外由于Pad操作是沒有帶任何權重信息的,所以當你打印ONNX模型時,ModelProto的GraphProto是沒有initializer這個屬性的。
0x5. onnx-simplifier
原本這里是要總結一些使用ONNX進行模型部署經常碰到一些因為版本兼容性,或者各種框架OP沒有對齊等原因導致的各種BUG。但是這樣的話顯得篇幅會很長,所以這里以一個經典的Pytorch轉ONNX的reshape問題為例子,來嘗試講解一下大老師的onnx-simplifier,個人認為這個問題是基于ONNX模型部署最經典的問題,希望解決這個問題的過程中大家能有所收獲。
問題發生在當我們想把下面這段代碼導出ONNX模型時:
import torchclass JustReshape(torch.nn.Module):def __init__(self):super(JustReshape, self).__init__()def forward(self, x):return x.view((x.shape[0], x.shape[1], x.shape[3], x.shape[2]))net = JustReshape() model_name = 'just_reshape.onnx' dummy_input = torch.randn(2, 3, 4, 5) torch.onnx.export(net, dummy_input, model_name, input_names=['input'], output_names=['output'])由于這個模型輸入維度是固定的,所以我們期望模型是這樣的:
但是,即使使用了ONNX的polished工具也只能獲得下面的模型:
要解決這個問題,有兩種方法,第一種是做一個強制類型轉換,將x.shape[0]類似的變量強制轉換為常量即int(x.shape[0]),或者使用大老師的onnx-simplifer來解決這一問題。之前一直好奇onnx-simplifer是怎么做的,最近對ONNX有了一些理解之后也能逐步看懂做法了。我來嘗試解釋一下。onnx-simplifer的核心思路就是利用onnxruntime推斷一遍ONNX的計算圖,然后使用常量輸出替代冗余的運算OP。
def simplify(model: Union[str, onnx.ModelProto], check_n: int = 0, perform_optimization: bool = True,skip_fuse_bn: bool = False, input_shapes: Optional[TensorShapes] = None, skipped_optimizers: Optional[Sequence[str]] = None, skip_shape_inference=False) \-> Tuple[onnx.ModelProto, bool]:if input_shapes is None:input_shapes = {}if type(model) == str:# 加載ONNX模型model = onnx.load(model)# 檢查ONNX模型格式是否正確,圖結構是否完整,節點是否正確等onnx.checker.check_model(model)# 深拷貝一份原始ONNX模型model_ori = copy.deepcopy(model)if not skip_shape_inference:# 獲取ONNX模型中特征圖的尺寸model = infer_shapes(model)input_shapes = check_and_update_input_shapes(model, input_shapes)if perform_optimization:model = optimize(model, skip_fuse_bn, skipped_optimizers)const_nodes = get_constant_nodes(model)res = forward_for_node_outputs(model, const_nodes, input_shapes=input_shapes)const_nodes = clean_constant_nodes(const_nodes, res)model = eliminate_const_nodes(model, const_nodes, res)onnx.checker.check_model(model)if not skip_shape_inference:model = infer_shapes(model)if perform_optimization:model = optimize(model, skip_fuse_bn, skipped_optimizers)check_ok = check(model_ori, model, check_n, input_shapes=input_shapes)return model, check_ok上面有一行:model = infer_shapes(model) 是獲取ONNX模型中特征圖的尺寸,它的具體實現如下:
def infer_shapes(model: onnx.ModelProto) -> onnx.ModelProto:try:model = onnx.shape_inference.infer_shapes(model)except:passreturn model我們保存一下調用了這個接口之后的ONNX模型,并將其可視化看一下:
相對于原始的ONNX模型,現在每一條線都新增了一個shape信息,代表它的前一個特征圖的shape是怎樣的。
接著,程序使用到了check_and_update_input_shapes接口,這個接口的代碼示例如下,它可以用來判斷輸入的格式是否正確以及輸入模型是否有所有的指定輸入節點。
def check_and_update_input_shapes(model: onnx.ModelProto, input_shapes: TensorShapes) -> TensorShapes:input_names = get_input_names(model)if None in input_shapes:if len(input_names) == 1:input_shapes[input_names[0]] = input_shapes[None]del input_shapes[None]else:raise RuntimeError('The model has more than 1 inputs, please use the format "input_name:dim0,dim1,...,dimN" in --input-shape')for x in input_shapes:if x not in input_names:raise RuntimeError('The model doesn\'t have input named "{}"'.format(x))return input_shapes在這個例子中,如果我們指定input_shapes為:{‘input’: [2, 3, 4, 5]},那么這個函數的輸出也為{‘input’: [2, 3, 4, 5]}。如果不指定,輸出就是{}。驗證這個函數的調用代碼如下所示:
確定了輸入沒有問題之后,程序會根據用戶指定是否優化ONNX模型進入優化函數,函數定義如下:
def optimize(model: onnx.ModelProto, skip_fuse_bn: bool, skipped_optimizers: Optional[Sequence[str]]) -> onnx.ModelProto:""":model參數: 待優化的ONXX模型.:return: 優化之后的ONNX模型.簡化之前, 使用這個方法產生會在'forward_all'用到的ValueInfo簡化之后,使用這個方法去折疊前一步產生的常量到initializer中并且消除沒被使用的常量"""onnx.checker.check_model(model)onnx.helper.strip_doc_string(model)optimizers_list = ['eliminate_deadend','eliminate_nop_dropout','eliminate_nop_cast','eliminate_nop_monotone_argmax', 'eliminate_nop_pad','extract_constant_to_initializer', 'eliminate_unused_initializer','eliminate_nop_transpose','eliminate_nop_flatten', 'eliminate_identity','fuse_add_bias_into_conv','fuse_consecutive_concats','fuse_consecutive_log_softmax','fuse_consecutive_reduce_unsqueeze', 'fuse_consecutive_squeezes','fuse_consecutive_transposes', 'fuse_matmul_add_bias_into_gemm','fuse_pad_into_conv', 'fuse_transpose_into_gemm', 'eliminate_duplicate_initializer']if not skip_fuse_bn:optimizers_list.append('fuse_bn_into_conv')if skipped_optimizers is not None:for opt in skipped_optimizers:try:optimizers_list.remove(opt)except ValueError:passmodel = onnxoptimizer.optimize(model, optimizers_list,fixed_point=True)onnx.checker.check_model(model)return model這個函數的功能是對原始的ONNX模型做一些圖優化工作,比如merge_bn,fuse_add_bias_into_conv等等。我們使用onnx.save保存一下這個例子中圖優化后的模型,可以發現它和優化前的可視化效果是一樣的,如下圖所示:
這是因為在這個模型中是沒有上面列舉到的那些可以做圖優化的情況,但是當我們打印一下ONNX模型我們會發現optimize過后的ONNX模型多出一些initializer數組:
這些數組存儲的就是這個圖中那些常量OP的具體值,通過這個處理我們就可以調用get_constant_nodes函數來獲取ONNX模型的常量OP了,這個函數的詳細解釋如下:
def get_constant_nodes(m: onnx.ModelProto) -> List[onnx.NodeProto]:const_nodes = []# 如果節點的name在ONNX的GraphProto的initizlizer數組里面,它就是靜態的tensorconst_tensors = [x.name for x in m.graph.initializer]# 顯示的常量OP也加進來const_tensors.extend([node.output[0]for node in m.graph.node if node.op_type == 'Constant'])# 一些節點的輸出shape是由輸入節點決定的,我們認為這個節點的輸出shape并不是常量,# 所以我們不需要簡化這種節點dynamic_tensors = []# 判斷是否為動態OPdef is_dynamic(node):if node.op_type in ['NonMaxSuppression', 'NonZero', 'Unique'] and node.input[0] not in const_tensors:return Trueif node.op_type in ['Reshape', 'Expand', 'Upsample', 'ConstantOfShape'] and len(node.input) > 1 and node.input[1] not in const_tensors:return Trueif node.op_type in ['Resize'] and ((len(node.input) > 2 and node.input[2] not in const_tensors) or (len(node.input) > 3 and node.input[3] not in const_tensors)):return Truereturn Falsefor node in m.graph.node:if any(x in dynamic_tensors for x in node.input):dynamic_tensors.extend(node.output)elif node.op_type == 'Shape':const_nodes.append(node)const_tensors.extend(node.output)elif is_dynamic(node):dynamic_tensors.extend(node.output)elif all([x in const_tensors for x in node.input]):const_nodes.append(node)const_tensors.extend(node.output)# 深拷貝return copy.deepcopy(const_nodes)在這個例子中,我們打印一下通過這個獲取常量OP二點函數之后,Graph中有哪些節點被看成了常量節點。
獲取了模型中所有的常量OP之后,我們需要把所有的靜態節點擴展到ONNX Graph的輸出節點列表中,然后利用onnxruntme執行一次forward:
def forward_for_node_outputs(model: onnx.ModelProto, nodes: List[onnx.NodeProto],input_shapes: Optional[TensorShapes] = None) -> Dict[str, np.ndarray]:if input_shapes is None:input_shapes = {}model = copy.deepcopy(model)# nodes 是Graph中所有的靜態OPadd_features_to_output(model, nodes)res = forward(model, input_shapes=input_shapes)return res其中add_features_to_output的定義如下:
def add_features_to_output(m: onnx.ModelProto, nodes: List[onnx.NodeProto]) -> None:"""Add features to output in pb, so that ONNX Runtime will output them.:param m: the model that will be run in ONNX Runtime:param nodes: nodes whose outputs will be added into the graph outputs"""# ONNX模型的graph擴展輸出節點,獲取所有靜態OP的輸出和原始輸出節點的輸出for node in nodes:for output in node.output:m.graph.output.extend([onnx.ValueInfoProto(name=output)])最后的forward函數就是利用onnxruntime推理獲得我們指定的輸出節點的值。這個函數這里不進行解釋。推理完成之后,進入下一個函數clean_constant_nodes,這個函數的定義如下:
def clean_constant_nodes(const_nodes: List[onnx.NodeProto], res: Dict[str, np.ndarray]):"""It seems not needed since commit 6f2a72, but maybe it still prevents some unknown bug:param const_nodes: const nodes detected by `get_constant_nodes`:param res: The dict containing all tensors, got by `forward_all`:return: The constant nodes which have an output in res"""return [node for node in const_nodes if node.output[0] in res]這個函數是用來清洗那些沒有被onnxruntime推理的靜態節點,但通過上面的optimize邏輯,我們的graph中其實已經不存在這個情況了(沒有被onnxruntime推理的靜態節點在圖優化階段會被優化掉),因此這個函數理論上是可以刪除的。這個地方是為了避免刪除掉有可能引發其它問題就保留了。
不過從一些實際經驗來看,還是保留吧,畢竟不能保證ONNX的圖優化就完全正確,前段時間剛發現了TensorRT圖優化出了一個BUG。保留這個函數可以提升一些程序的穩定性。
接下來就是這個onnx-simplifier最核心的步驟了,即將常量節點從原始的ONNX刪除,函數接口為eliminate_const_nodes:
def eliminate_const_nodes(model: onnx.ModelProto, const_nodes: List[onnx.NodeProto],res: Dict[str, np.ndarray]) -> onnx.ModelProto:""":model參數: 原始ONNX模型:const_nodes參數: 使用`get_constant_nodes`獲得的靜態OP:res參數: 包含所有輸出Tensor的字典:return: 簡化后的模型. 所有冗余操作都已刪除."""for i, node in enumerate(model.graph.node):if node in const_nodes:for output in node.output:new_node = copy.deepcopy(node)new_node.name = "node_" + outputnew_node.op_type = 'Constant'new_attr = onnx.helper.make_attribute('value',onnx.numpy_helper.from_array(res[output], name=output))del new_node.input[:]del new_node.attribute[:]del new_node.output[:]new_node.output.extend([output])new_node.attribute.extend([new_attr])insert_elem(model.graph.node, i + 1, new_node)del model.graph.node[i]return model運行這個函數之后我們獲得的ONNX模型可視化結果是這樣子的:
注意,這里獲得的ONNX模型中雖然常量節點已經從Graph中斷開了,即相當于這個DAG里面多了一些單獨的點,但是這些點還是存在的。因此,我們最后還需要執行一次optimize就可以獲得最終簡化后的ONNX模型了。最終簡化后的ONNX模型如下圖所示:
0x6. 總結
好了,介于篇幅原因,介紹ONNX的第一篇文章就介紹到這里了,后續可能會結合更多實踐的經驗來談談ONNX了,例如OneFlow模型導出ONNX進行部署?。總之,文章很長,謝謝你的觀看,希望這篇文章有幫助到你。最后歡迎star大老師的onnx-simplifier。
0x7. 參考資料
- 【1】https://zhuanlan.zhihu.com/p/86867138
- 【2】https://oldpan.me/archives/talk-about-onnx
- 【3】https://blog.csdn.net/chengzi_comm/article/details/53199278
- 【4】https://www.jianshu.com/p/a24c88c0526a
- 【5】https://bindog.github.io/blog/2020/03/13/deep-learning-model-convert-and-depoly/
- 【6】 https://github.com/daquexian/onnx-simplifier
總結
- 上一篇: 分割线不显示_90后都30岁了,为什么还
- 下一篇: TabError- inconsiste