ONNX初探
0x0. 背景
最近看了一些ONNX的資料,一個最大的感受就是這些資料太凌亂了。大多數都是在介紹ONNX模型轉換中碰到的坑點以及解決辦法。很少有文章可以系統的介紹ONNX的背景,分析ONNX格式,ONNX簡化方法等。所以,綜合了相當多資料之后我準備寫一篇ONNX相關的文章,希望對大家有用。
0x1. 什么是ONNX?
簡單描述一下官方介紹,開放神經網絡交換(Open Neural Network Exchange)簡稱ONNX是微軟和Facebook提出用來表示深度學習模型的開放格式。所謂開放就是ONNX定義了一組和環(huán)境,平臺均無關的標準格式,來增強各種AI模型的可交互性。
換句話說,無論你使用何種訓練框架訓練模型(比如TensorFlow/Pytorch/OneFlow/Paddle),在訓練完畢后你都可以將這些框架的模型統一轉換為ONNX這種統一的格式進行存儲。注意ONNX文件不僅僅存儲了神經網絡模型的權重,同時也存儲了模型的結構信息以及網絡中每一層的輸入輸出和一些其它的輔助信息。我們直接從onnx的官方模型倉庫拉一個yolov3-tiny的onnx模型(地址為:https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/tiny-yolov3/model)用Netron可視化一下看看ONNX模型長什么樣子。

這里我們可以看到ONNX的版本信息,這個ONNX模型是由Keras導出來的,以及模型的輸入輸出等信息,如果你對模型的輸入輸出有疑問可以直接看:https://github.com/onnx/models/blob/master/vision/object_detection_segmentation/tiny-yolov3/README.md。
在獲得ONNX模型之后,模型部署人員自然就可以將這個模型部署到兼容ONNX的運行環(huán)境中去。這里一般還會設計到額外的模型轉換工作,典型的比如在Android端利用NCNN部署ONNX格式模型,那么就需要將ONNX利用NCNN的轉換工具轉換到NCNN所支持的bin和param格式。
但在實際使用ONNX的過程中,大多數人對ONNX了解得并不多,僅僅認為它只是一個完成模型轉換和部署工具人而已,我們可以利用它完成模型轉換和部署。正是因為對ONNX的不了解,在模型轉換過程中出現的各種不兼容或者不支持讓很多人浪費了大量時間。這篇文章將從理論和實踐2個方面談一談ONNX。
0x2. ProtoBuf簡介
在分析ONNX組織格式前我們需要了解Protobuf, 如果你比較了解Protobuf可以略過此節(jié)。ONNX作為一個文件格式,我們自然需要一定的規(guī)則去讀取我們想要的信息或者是寫入我們需要保存信息。ONNX使用的是Protobuf這個序列化數據結構去存儲神經網絡的權重信息。熟悉Caffe或者Caffe2的同學應該知道,它們的模型存儲數據結構協議也是Protobuf。這個從安裝ONNX包的時候也可以看到:

Protobuf是一種輕便高效的結構化數據存儲格式,可以用于結構化數據串行化,或者說序列化。它很適合做數據存儲或數據交換格式。可用于通訊協議、數據存儲等領域的語言無關、平臺無關、可擴展的序列化結構數據格式。目前提供了 C++、Java、Python 三種語言的 API(摘自官方介紹)。
Protobuf協議是一個以*.proto后綴文件為基礎的,這個文件描述了用戶自定義的數據結構。如果需要了解更多細節(jié)請參考0x7節(jié)的資料3,這里只是想表達ONNX是基于Protobuf來做數據存儲和傳輸,那么自然onnx.proto就是ONNX格式文件了,接下來我們就分析一下ONNX格式。
0x3. ONNX格式分析
這一節(jié)我們來分析一下ONNX的組織格式,上面提到ONNX中最核心的部分就是onnx.proto(https://github.com/onnx/onnx/blob/master/onnx/onnx.proto)這個文件了,它定義了ONNX這個數據協議的規(guī)則和一些其它信息。現在是2021年1月,這個文件有700多行,我們沒有必要把這個文件里面的每一行都貼出來,我們只要搞清楚里面的核心部分即可。在這個文件里面以message關鍵字開頭的對象是我們需要關心的。我們列一下最核心的幾個對象并解釋一下它們之間的關系。
ModelProtoGraphProtoNodeProtoValueInfoProtoTensorProtoAttributeProto
當我們加載了一個ONNX之后,我們獲得的就是一個ModelProto,它包含了一些版本信息,生產者信息和一個GraphProto。在GraphProto里面又包含了四個repeated數組,它們分別是node(NodeProto類型),input(ValueInfoProto類型),output(ValueInfoProto類型)和initializer(TensorProto類型),其中node中存放了模型中所有的計算節(jié)點,input存放了模型的輸入節(jié)點,output存放了模型中所有的輸出節(jié)點,initializer存放了模型的所有權重參數。
我們知道要完整的表達一個神經網絡,不僅僅要知道網絡的各個節(jié)點信息,還要知道它們的拓撲關系。這個拓撲關系在ONNX中是如何表示的呢?ONNX的每個計算節(jié)點都會有input和output兩個數組,這兩個數組是string類型,通過input和output的指向關系,我們就可以利用上述信息快速構建出一個深度學習模型的拓撲圖。這里要注意一下,GraphProto中的input數組不僅包含我們一般理解中的圖片輸入的那個節(jié)點,還包含了模型中所有的權重。例如,Conv層里面的W權重實體是保存在initializer中的,那么相應的會有一個同名的輸入在input中,其背后的邏輯應該是把權重也看成模型的輸入,并通過initializer中的權重實體來對這個輸入做初始化,即一個賦值的過程。
最后,每個計算節(jié)點中還包含了一個AttributeProto數組,用來描述該節(jié)點的屬性,比如Conv節(jié)點或者說卷積層的屬性包含group,pad,strides等等,每一個計算節(jié)點的屬性,輸入輸出信息都詳細記錄在https://github.com/onnx/onnx/blob/master/docs/Operators.md。
0x4. onnx.helper
現在我們知道ONNX是把一個網絡的每一層或者說一個算子當成節(jié)點node,使用這些Node去構建一個Graph,即一個網絡。最后將Graph和其它的生產者信息,版本信息等合并在一起生成一個Model,也即是最終的ONNX模型文件。在構建ONNX模型的時候,https://github.com/onnx/onnx/blob/master/onnx/helper.py這個文件非常重要,我們可以利用它提供的make_node,make_graph,make_tensor等等接口完成一個ONNX模型的構建,一個示例如下:
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto
# The protobuf definition can be found here:
# https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
# Create one input (ValueInfoProto)
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [3, 2])
pads = helper.make_tensor_value_info('pads', TensorProto.FLOAT, [1, 4])
value = helper.make_tensor_value_info('value', AttributeProto.FLOAT, [1])
# Create one output (ValueInfoProto)
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [3, 4])
# Create a node (NodeProto) - This is based on Pad-11
node_def = helper.make_node(
'Pad', # node name
['X', 'pads', 'value'], # inputs
['Y'], # outputs
mode='constant', # attributes
)
# Create the graph (GraphProto)
graph_def = helper.make_graph(
[node_def],
'test-model',
[X, pads, value],
[Y],
)
# Create the model (ModelProto)
model_def = helper.make_model(graph_def, producer_name='onnx-example')
print('The model is:\n{}'.format(model_def))
onnx.checker.check_model(model_def)
print('The model is checked!')
這個官方示例為我們演示了如何使用onnx.helper的make_tensor,make_tensor_value_info,make_attribute,make_node,make_graph,make_node等方法來完整構建了一個ONNX模型。需要注意的是在上面的例子中,輸入數據是一個一維Tensor,初始維度為[2],這也是為什么經過維度為[1,4]的Pad操作之后獲得的輸出Tensor維度為[3,4]。另外由于Pad操作是沒有帶任何權重信息的,所以當你打印ONNX模型時,ModelProto的GraphProto是沒有initializer這個屬性的。
0x5. onnx-simplifier
原本這里是要總結一些使用ONNX進行模型部署經常碰到一些因為版本兼容性,或者各種框架OP沒有對齊等原因導致的各種BUG。但是這樣會顯得文章很長,所以這里以一個經典的Pytorch轉ONNX的reshape問題為例子,來嘗試講解一下大老師的onnx-simplifier是怎么處理的,個人認為這個問題是基于ONNX進行模型部署最經典的問題。希望在解決這個問題的過程中大家能有所收獲。
問題發(fā)生在當我們想把下面這段代碼導出ONNX模型時:
import torch
class JustReshape(torch.nn.Module):
def __init__(self):
super(JustReshape, self).__init__()
def forward(self, x):
return x.view((x.shape[0], x.shape[1], x.shape[3], x.shape[2]))
net = JustReshape()
model_name = 'just_reshape.onnx'
dummy_input = torch.randn(2, 3, 4, 5)
torch.onnx.export(net, dummy_input, model_name, input_names=['input'], output_names=['output'])
由于這個模型輸入維度是固定的,所以我們期望模型是這樣的:

但是,即使使用了ONNX的polished工具也只能獲得下面的模型:

要解決這個問題,有兩種方法,第一種是做一個強制類型轉換,將x.shape[0]類似的變量強制轉換為常量即int(x.shape[0]),或者使用大老師的onnx-simplifer來解決這一問題。
之前一直好奇onnx-simplifer是怎么做的,最近對ONNX有了一些理解之后也能逐步看懂做法了。我來嘗試解釋一下。onnx-simplifer的核心思路就是利用onnxruntime推斷一遍ONNX的計算圖,然后使用常量輸出替代冗余的運算OP。主體代碼為:
def simplify(model: Union[str, onnx.ModelProto], check_n: int = 0, perform_optimization: bool = True,
skip_fuse_bn: bool = False, input_shapes: Optional[TensorShapes] = None, skipped_optimizers: Optional[Sequence[str]] = None, skip_shape_inference=False) \
-> Tuple[onnx.ModelProto, bool]:
if input_shapes is None:
input_shapes = {}
if type(model) == str:
# 加載ONNX模型
model = onnx.load(model)
# 檢查ONNX模型格式是否正確,圖結構是否完整,節(jié)點是否正確等
onnx.checker.check_model(model)
# 深拷貝一份原始ONNX模型
model_ori = copy.deepcopy(model)
if not skip_shape_inference:
# 獲取ONNX模型中特征圖的尺寸
model = infer_shapes(model)
input_shapes = check_and_update_input_shapes(model, input_shapes)
if perform_optimization:
model = optimize(model, skip_fuse_bn, skipped_optimizers)
const_nodes = get_constant_nodes(model)
res = forward_for_node_outputs(
model, const_nodes, input_shapes=input_shapes)
const_nodes = clean_constant_nodes(const_nodes, res)
model = eliminate_const_nodes(model, const_nodes, res)
onnx.checker.check_model(model)
if not skip_shape_inference:
model = infer_shapes(model)
if perform_optimization:
model = optimize(model, skip_fuse_bn, skipped_optimizers)
check_ok = check(model_ori, model, check_n, input_shapes=input_shapes)
return model, check_ok
上面有一行:model = infer_shapes(model) 是獲取ONNX模型中特征圖的尺寸,它的具體實現如下:
def infer_shapes(model: onnx.ModelProto) -> onnx.ModelProto:
try:
model = onnx.shape_inference.infer_shapes(model)
except:
pass
return model
我們保存一下調用了這個接口之后的ONNX模型,并將其可視化看一下:

相對于原始的ONNX模型,現在每一條線都新增了一個shape信息,代表它的前一個特征圖的shape是怎樣的。
接著,程序使用到了check_and_update_input_shapes接口,這個接口的代碼示例如下,它可以用來判斷輸入的格式是否正確以及輸入模型是否存在所有的指定輸入節(jié)點。
def check_and_update_input_shapes(model: onnx.ModelProto, input_shapes: TensorShapes) -> TensorShapes:
input_names = get_input_names(model)
if None in input_shapes:
if len(input_names) == 1:
input_shapes[input_names[0]] = input_shapes[None]
del input_shapes[None]
else:
raise RuntimeError(
'The model has more than 1 inputs, please use the format "input_name:dim0,dim1,...,dimN" in --input-shape')
for x in input_shapes:
if x not in input_names:
raise RuntimeError(
'The model doesn\'t have input named "{}"'.format(x))
return input_shapes
在這個例子中,如果我們指定input_shapes為:{'input': [2, 3, 4, 5]},那么這個函數的輸出也為{'input': [2, 3, 4, 5]}。如果不指定,輸出就是{}。驗證這個函數的調用代碼如下所示:

確定了輸入沒有問題之后,程序會根據用戶指定是否優(yōu)化ONNX模型進入優(yōu)化函數,函數定義如下:
def optimize(model: onnx.ModelProto, skip_fuse_bn: bool, skipped_optimizers: Optional[Sequence[str]]) -> onnx.ModelProto:
"""
:model參數: 待優(yōu)化的ONXX模型.
:return: 優(yōu)化之后的ONNX模型.
簡化之前, 使用這個方法產生會在'forward_all'用到的ValueInfo
簡化之后,使用這個方法去折疊前一步產生的常量到initializer中并且消除沒被使用的常量
"""
onnx.checker.check_model(model)
onnx.helper.strip_doc_string(model)
optimizers_list = [
'eliminate_deadend',
'eliminate_nop_dropout',
'eliminate_nop_cast',
'eliminate_nop_monotone_argmax', 'eliminate_nop_pad',
'extract_constant_to_initializer', 'eliminate_unused_initializer',
'eliminate_nop_transpose',
'eliminate_nop_flatten', 'eliminate_identity',
'fuse_add_bias_into_conv',
'fuse_consecutive_concats',
'fuse_consecutive_log_softmax',
'fuse_consecutive_reduce_unsqueeze', 'fuse_consecutive_squeezes',
'fuse_consecutive_transposes', 'fuse_matmul_add_bias_into_gemm',
'fuse_pad_into_conv', 'fuse_transpose_into_gemm', 'eliminate_duplicate_initializer'
]
if not skip_fuse_bn:
optimizers_list.append('fuse_bn_into_conv')
if skipped_optimizers is not None:
for opt in skipped_optimizers:
try:
optimizers_list.remove(opt)
except ValueError:
pass
model = onnxoptimizer.optimize(model, optimizers_list,
fixed_point=True)
onnx.checker.check_model(model)
return model
這個函數的功能是對原始的ONNX模型做一些圖優(yōu)化工作,比如merge_bn,fuse_add_bias_into_conv等等。我們使用onnx.save保存一下這個例子中圖優(yōu)化后的模型,可以發(fā)現它和優(yōu)化前的可視化效果是一樣的,如下圖所示:

這是因為在這個模型中是沒有上面列舉到的那些可以做圖優(yōu)化的情況,但是當我們打印一下ONNX模型我們會發(fā)現optimize過后的ONNX模型多出一些initializer數組:

這些數組存儲的就是這個圖中那些常量OP的具體值,通過這個處理我們就可以調用get_constant_nodes函數來獲取ONNX模型的常量OP了,這個函數的詳細解釋如下:
def get_constant_nodes(m: onnx.ModelProto) -> List[onnx.NodeProto]:
const_nodes = []
# 如果節(jié)點的name在ONNX的GraphProto的initizlizer數組里面,它就是靜態(tài)的tensor
const_tensors = [x.name for x in m.graph.initializer]
# 顯示的常量OP也加進來
const_tensors.extend([node.output[0]
for node in m.graph.node if node.op_type == 'Constant'])
# 一些節(jié)點的輸出shape是由輸入節(jié)點決定的,我們認為這個節(jié)點的輸出shape并不是常量,
# 所以我們不需要簡化這種節(jié)點
dynamic_tensors = []
# 判斷是否為動態(tài)OP
def is_dynamic(node):
if node.op_type in ['NonMaxSuppression', 'NonZero', 'Unique'] and node.input[0] not in const_tensors:
return True
if node.op_type in ['Reshape', 'Expand', 'Upsample', 'ConstantOfShape'] and len(node.input) > 1 and node.input[1] not in const_tensors:
return True
if node.op_type in ['Resize'] and ((len(node.input) > 2 and node.input[2] not in const_tensors) or (len(node.input) > 3 and node.input[3] not in const_tensors)):
return True
return False
for node in m.graph.node:
if any(x in dynamic_tensors for x in node.input):
dynamic_tensors.extend(node.output)
elif node.op_type == 'Shape':
const_nodes.append(node)
const_tensors.extend(node.output)
elif is_dynamic(node):
dynamic_tensors.extend(node.output)
elif all([x in const_tensors for x in node.input]):
const_nodes.append(node)
const_tensors.extend(node.output)
# 深拷貝
return copy.deepcopy(const_nodes)
在這個例子中,我們打印一下執(zhí)行這個獲取常量OP函數之后,Graph中有哪些OP被看成了常量OP。

獲取了模型中所有的常量OP之后,我們需要把所有的靜態(tài)節(jié)點擴展到ONNX Graph的輸出節(jié)點列表中,然后利用onnxruntme執(zhí)行一次forward:
def forward_for_node_outputs(model: onnx.ModelProto, nodes: List[onnx.NodeProto],
input_shapes: Optional[TensorShapes] = None) -> Dict[str, np.ndarray]:
if input_shapes is None:
input_shapes = {}
model = copy.deepcopy(model)
# nodes 是Graph中所有的靜態(tài)OP
add_features_to_output(model, nodes)
res = forward(model, input_shapes=input_shapes)
return res
其中add_features_to_output的定義如下:
def add_features_to_output(m: onnx.ModelProto, nodes: List[onnx.NodeProto]) -> None:
"""
Add features to output in pb, so that ONNX Runtime will output them.
:param m: the model that will be run in ONNX Runtime
:param nodes: nodes whose outputs will be added into the graph outputs
"""
# ONNX模型的graph擴展輸出節(jié)點,獲取所有靜態(tài)OP的輸出和原始輸出節(jié)點的輸出
for node in nodes:
for output in node.output:
m.graph.output.extend([onnx.ValueInfoProto(name=output)])
最后的forward函數就是利用onnxruntime推理獲得我們指定的輸出節(jié)點的值。這個函數這里不進行解釋。推理完成之后,進入下一個函數clean_constant_nodes,這個函數的定義如下:
def clean_constant_nodes(const_nodes: List[onnx.NodeProto], res: Dict[str, np.ndarray]):
"""
It seems not needed since commit 6f2a72, but maybe it still prevents some unknown bug
:param const_nodes: const nodes detected by `get_constant_nodes`
:param res: The dict containing all tensors, got by `forward_all`
:return: The constant nodes which have an output in res
"""
return [node for node in const_nodes if node.output[0] in res]
這個函數是用來清洗那些沒有被onnxruntime推理的靜態(tài)節(jié)點,但通過上面的optimize邏輯,我們的graph中其實已經不存在這個情況了(沒有被onnxruntime推理的靜態(tài)節(jié)點在圖優(yōu)化階段會被優(yōu)化掉),因此這個函數理論上是可以刪除的。這個地方是為了避免刪除掉有可能引發(fā)其它問題就保留了。
不過從一些實際經驗來看,還是保留吧,畢竟不能保證ONNX的圖優(yōu)化就完全正確,前段時間剛發(fā)現了TensorRT圖優(yōu)化出了一個BUG。保留這個函數可以提升一些程序的穩(wěn)定性。

接下來就是這個onnx-simplifier最核心的步驟了,即將常量節(jié)點從原始的ONNX Graph中移除,函數接口為eliminate_const_nodes:
def eliminate_const_nodes(model: onnx.ModelProto, const_nodes: List[onnx.NodeProto],
res: Dict[str, np.ndarray]) -> onnx.ModelProto:
"""
:model參數: 原始ONNX模型
:const_nodes參數: 使用`get_constant_nodes`獲得的靜態(tài)OP
:res參數: 包含所有輸出Tensor的字典
:return: 簡化后的模型. 所有冗余操作都已刪除.
"""
for i, node in enumerate(model.graph.node):
if node in const_nodes:
for output in node.output:
new_node = copy.deepcopy(node)
new_node.name = "node_" + output
new_node.op_type = 'Constant'
new_attr = onnx.helper.make_attribute(
'value',
onnx.numpy_helper.from_array(res[output], name=output)
)
del new_node.input[:]
del new_node.attribute[:]
del new_node.output[:]
new_node.output.extend([output])
new_node.attribute.extend([new_attr])
insert_elem(model.graph.node, i + 1, new_node)
del model.graph.node[i]
return model
運行這個函數之后我們獲得的ONNX模型可視化結果是這樣子的:

注意,這里獲得的ONNX模型中雖然常量節(jié)點已經從Graph中斷開了,即相當于這個DAG里面多了一些單獨的點,但是這些點還是存在的。因此,我們再執(zhí)行一次optimize就可以獲得最終簡化后的ONNX模型了。最終簡化后的ONNX模型如下圖所示:

0x6. 總結
介于篇幅原因,介紹ONNX的第一篇文章就介紹到這里了,后續(xù)可能會結合更多實踐的經驗來談談ONNX了,例如OneFlow模型導出ONNX進行部署???傊?,文章很長,謝謝你的觀看,希望這篇文章有幫助到你。最后歡迎star大老師的onnx-simplifier。
0x7. 參考資料
【1】https://zhuanlan.zhihu.com/p/86867138 【2】https://oldpan.me/archives/talk-about-onnx 【3】https://blog.csdn.net/chengzi_comm/article/details/53199278 【4】https://www.jianshu.com/p/a24c88c0526a 【5】https://bindog.github.io/blog/2020/03/13/deep-learning-model-convert-and-depoly/ 【6】 https://github.com/daquexian/onnx-simplifier
歡迎關注GiantPandaCV, 在這里你將看到獨家的深度學習分享,堅持原創(chuàng),每天分享我們學習到的新鮮知識。( ? ?ω?? )?
有對文章相關的問題,或者想要加入交流群,歡迎添加BBuf微信:
為了方便讀者獲取資料以及我們公眾號的作者發(fā)布一些Github工程的更新,我們成立了一個QQ群,二維碼如下,感興趣可以加入。
