ONNX 模型分析與使用
地址:https://zhuanlan.zhihu.com/p/371177698
本文大部分內(nèi)容為對 ONNX 官方資料的總結(jié)和翻譯,部分知識點參考網(wǎng)上質(zhì)量高的博客。
01
深度學習算法大多通過計算數(shù)據(jù)流圖來完成神經(jīng)網(wǎng)絡的深度學習過程。一些框架(例如CNTK,Caffe2,Theano和TensorFlow)使用靜態(tài)圖形,而其他框架(例如 PyTorch 和 Chainer)使用動態(tài)圖形。但是這些框架都提供了接口,使開發(fā)人員可以輕松構(gòu)建計算圖和運行時,以優(yōu)化的方式處理圖。這些圖用作中間表示(IR),捕獲開發(fā)人員源代碼的特定意圖,有助于優(yōu)化和轉(zhuǎn)換在特定設備(CPU,GPU,F(xiàn)PGA等)上運行。
ONNX 的本質(zhì)只是一套開放的 ML 模型標準,模型文件存儲的只是網(wǎng)絡的拓撲結(jié)構(gòu)和權(quán)重(其實每個深度學習框架最后保存的模型都是類似的),脫離開框架是沒辦法對模型直接進行 inference的。
1.1 為什么使用通用 IR
現(xiàn)在很多的深度學習框架提供的功能都是類似的,但是在 API、計算圖和 runtime 方面卻是獨立的,這就給 AI 開發(fā)者在不同平臺部署不同模型帶來了很多困難和挑戰(zhàn),ONNX 的目的在于提供一個跨框架的模型中間表達框架,用于模型轉(zhuǎn)換和部署。ONNX 提供的計算圖是通用的,格式也是開源的。
02
Open Neural Network Exchange Intermediate Representation (ONNX IR) Specification.
ONNX 結(jié)構(gòu)的定義文件 .proto 和 .prpto3 可以在 onnx folder(https://github.com/onnx/onnx/tree/master/onnx) 目錄下找到,文件遵循的是谷歌 Protobuf 協(xié)議。ONNX 是一個開放式規(guī)范,由以下組件組成:
可擴展計算圖模型的定義
標準數(shù)據(jù)類型的定義
內(nèi)置運算符的定義
IR6 版本的 ONNX 只能用于推理(inference),從 IR7 開始 ONNX 支持訓練(training)。onnx.proto 主要的對象如下:
ModelProto
GraphProto
NodeProto
AttributeProto
ValueInfoProto
TensorProto
他們之間的關(guān)系:加載 ONNX 模型后會得到一個 ModelProto,它包含了一些版本信息,生產(chǎn)者信息和一個非常重要的 GraphProto;在 GraphProto 中包含了四個關(guān)鍵的 repeated 數(shù)組,分別是node (NodeProto 類型),input(ValueInfoProto 類型),output(ValueInfoProto 類型)和 initializer (TensorProto 類型),其中 node 中存放著模型中的所有計算節(jié)點,input 中存放著模型所有的輸入節(jié)點,output 存放著模型所有的輸出節(jié)點,initializer 存放著模型所有的權(quán)重;節(jié)點與節(jié)點之間的拓撲定義可以通過 input 和output 這兩個 string 數(shù)組的指向關(guān)系得到,這樣利用上述信息我們可以快速構(gòu)建出一個深度學習模型的拓撲圖。最后每個計算節(jié)點當中還包含了一個 AttributeProto 數(shù)組,用于描述該節(jié)點的屬性,例如 Conv 層的屬性包含 group,pads 和strides 等等,具體每個計算節(jié)點的屬性、輸入和輸出可以參考這個 Operators.md 文檔。
需要注意的是,上面所說的 GraphProto 中的 input 輸入數(shù)組不僅僅包含我們一般理解中的圖片輸入的那個節(jié)點,還包含了模型當中所有權(quán)重。舉例,Conv 層中的 W 權(quán)重實體是保存在 initializer 當中的,那么相應的會有一個同名的輸入在 input 當中,其背后的邏輯應該是把權(quán)重也看作是模型的輸入,并通過 initializer 中的權(quán)重實體來對這個輸入做初始化(也就是把值填充進來)
2.1 Model
模型結(jié)構(gòu)的主要目的是將元數(shù)據(jù)( meta data)與圖形(graph)相關(guān)聯(lián),圖形包含所有可執(zhí)行元素。首先,讀取模型文件時使用元數(shù)據(jù),為實現(xiàn)提供所需的信息,以確定它是否能夠:執(zhí)行模型,生成日志消息,錯誤報告等功能。此外元數(shù)據(jù)對工具很有用,例如IDE和模型庫,它需要它來告知用戶給定模型的目的和特征。
每個 model 有以下組件:

2.2 Operators Sets
每個模型必須明確命名它依賴于其功能的運算符集。操作員集定義可用的操作符,其版本和狀態(tài)。每個模型按其域定義導入的運算符集。所有模型都隱式導入默認的 ONNX 運算符集。
運算符集(Operators Sets)對象的屬性如下:

2.3 ONNX Operator
圖( graph)中使用的每個運算符必須由模型(model)導入的一個運算符集明確聲明。
運算符(Operator)對象定義的屬性如下:

2.4 ONNX Graph
序列化圖由一組元數(shù)據(jù)字段(metadata),模型參數(shù)列表(a list of model parameters,)和計算節(jié)點列表組成(a list of computation nodes)。每個計算數(shù)據(jù)流圖被構(gòu)造為拓撲排序的節(jié)點列表,這些節(jié)點形成圖形,其必須沒有周期。每個節(jié)點代表對運營商的呼叫。每個節(jié)點具有零個或多個輸入以及一個或多個輸出。
圖表(Graph)對象具有以下屬性:

2.5 ValueInfo
ValueInfo 對象屬性如下:

2.6 Standard data types
ONNX 標準有兩個版本,主要區(qū)別在于支持的數(shù)據(jù)類型和算子不同。計算圖 graphs、節(jié)點 nodes和計算圖的 initializers 支持的數(shù)據(jù)類型如下。原始數(shù)字,字符串和布爾類型必須用作張量的元素。
2.6.1 Tensor Element Types

2.6.2 Input / Output Data Types
以下類型用于定義計算圖和節(jié)點輸入和輸出的類型。

ONNX 現(xiàn)階段沒有定義稀疏張量類型。
03
3.1 加載模型
1. Loading an ONNX model
import onnx# onnx_model is an in-mempry ModelProtoonnx_model = onnx.load('path/to/the/model.onnx') # 加載 onnx 模型
2. Loading an ONNX Model with External Data
【默認加載模型方式】如果外部數(shù)據(jù)(external data)和模型文件在同一個目錄下,僅使用 onnx.load() 即可加載模型,方法見上小節(jié)。
如果外部數(shù)據(jù)(external data)和模型文件不在同一個目錄下,在使用 onnx_load() 函數(shù)后還需使用 load_external_data_for_model() 函數(shù)指定外部數(shù)據(jù)路徑。
import onnxfrom onnx.external_data_helper import load_external_data_for_modelonnx_model = onnx.load('path/to/the/model.onnx', load_external_data=False)load_external_data_for_model(onnx_model, 'data/directory/path/')# Then the onnx_model has loaded the external data from the specific directory
3. Converting an ONNX Model to External Data
from onnx.external_data_helper import convert_model_to_external_data# onnx_model is an in-memory ModelProtoonnx_model = ...convert_model_to_external_data(onnx_model, all_tensors_to_one_file=True, location='filename', size_threshold=1024, convert_attribute=False)# Then the onnx_model has converted raw data as external data# Must be followed by save
3.2 保存模型
1. Saving an ONNX Model
import onnx# onnx_model is an in-memory ModelProtoonnx_model = ...# Save the ONNX modelonnx.save(onnx_model, 'path/to/the/model.onnx')
2. Converting and Saving an ONNX Model to External Data
import onnx# onnx_model is an in-memory ModelProtoonnx_model = ...onnx.save_model(onnx_model, 'path/to/save/the/model.onnx', save_as_external_data=True, all_tensors_to_one_file=True, location='filename', size_threshold=1024, convert_attribute=False)# Then the onnx_model has converted raw data as external data and saved to specific directory
3.3 Manipulating TensorProto and Numpy Array
import numpyimport onnxfrom onnx import numpy_helper# Preprocessing: create a Numpy arraynumpy_array = numpy.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=float)print('Original Numpy array:\n{}\n'.format(numpy_array))# Convert the Numpy array to a TensorPrototensor = numpy_helper.from_array(numpy_array)print('TensorProto:\n{}'.format(tensor))# Convert the TensorProto to a Numpy arraynew_array = numpy_helper.to_array(tensor)print('After round trip, Numpy array:\n{}\n'.format(new_array))# Save the TensorProtowith open('tensor.pb', 'wb') as f:f.write(tensor.SerializeToString())# Load a TensorProtonew_tensor = onnx.TensorProto()with open('tensor.pb', 'rb') as f:new_tensor.ParseFromString(f.read())print('After saving and loading, new TensorProto:\n{}'.format(new_tensor))
3.4 創(chuàng)建ONNX模型
可以通過 helper 模塊提供的函數(shù) helper.make_graph 完成創(chuàng)建 ONNX 格式的模型。創(chuàng)建 graph 之前,需要先創(chuàng)建相應的 NodeProto(node),參照文檔設定節(jié)點的屬性,指定該節(jié)點的輸入與輸出,如果該節(jié)點帶有權(quán)重那還需要創(chuàng)建相應的ValueInfoProto 和 TensorProto 分別放入 graph 中的 input 和 initializer 中,以上步驟缺一不可。
import onnxfrom onnx import helperfrom onnx import AttributeProto, TensorProto, GraphProto# The protobuf definition can be found here:# https://github.com/onnx/onnx/blob/master/onnx/onnx.proto# Create one input (ValueInfoProto)X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [3, 2])pads = helper.make_tensor_value_info('pads', TensorProto.FLOAT, [1, 4])value = helper.make_tensor_value_info('value', AttributeProto.FLOAT, [1])# Create one output (ValueInfoProto)Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [3, 4])# Create a node (NodeProto) - This is based on Pad-11node_def = helper.make_node('Pad', # name['X', 'pads', 'value'], # inputs['Y'], # outputsmode='constant', # attributes)# Create the graph (GraphProto)graph_def = helper.make_graph([node_def], # nodes'test-model', # name[X, pads, value], # inputs[Y], # outputs)# Create the model (ModelProto)model_def = helper.make_model(graph_def, producer_name='onnx-example')print('The model is:\n{}'.format(model_def))onnx.checker.check_model(model_def)print('The model is checked!')
3.5 檢查模型
在完成 ONNX 模型加載或者創(chuàng)建后,有必要對模型進行檢查,使用 onnx.check.check_model() 函數(shù)。
import onnx# Preprocessing: load the ONNX modelmodel_path = 'path/to/the/model.onnx'onnx_model = onnx.load(model_path)print('The model is:\n{}'.format(onnx_model))# Check the modeltry:onnx.checker.check_model(onnx_model)except onnx.checker.ValidationError as e:print('The model is invalid: %s' % e)else:print('The model is valid!')
3.6 實用功能函數(shù)
函數(shù) extract_model() 可以從 ONNX 模型中提取子模型,子模型由輸入和輸出張量的名稱定義。這個功能方便我們 debug 原模型和轉(zhuǎn)換后的 ONNX 模型輸出結(jié)果是否一致(誤差小于某個閾值),不再需要我們手動去修改 ONNX 模型。
import onnxinput_path = 'path/to/the/original/model.onnx'output_path = 'path/to/save/the/extracted/model.onnx'input_names = ['input_0', 'input_1', 'input_2']output_names = ['output_0', 'output_1']onnx.utils.extract_model(input_path, output_path, input_names, output_names)
3.7 工具
函數(shù) update_inputs_outputs_dims() 可以將模型輸入和輸出的維度更新為參數(shù)中指定的值,可以使用 dim_param 提供靜態(tài)和動態(tài)尺寸大小。
import onnxfrom onnx.tools import update_model_dimsmodel = onnx.load('path/to/the/model.onnx')# Here both 'seq', 'batch' and -1 are dynamic using dim_param.variable_length_model = update_model_dims.update_inputs_outputs_dims(model, {'input_name': ['seq', 'batch', 3, -1]}, {'output_name': ['seq', 'batch', 1, -1]})# need to check model after the input/output sizes are updatedonnx.checker.check_model(variable_length_model )
參考資料
https://zhuanlan.zhihu.com/p/41255090
https://bindog.github.io/blog/2020/03/13/deep-learning-model-convert-and-depoly/
https://github.com/onnx/tutorials
猜您喜歡:
附下載 |《TensorFlow 2.0 深度學習算法實戰(zhàn)》
《基于深度神經(jīng)網(wǎng)絡的少樣本學習綜述》
