模型部署翻車(chē)記:pytorch轉(zhuǎn)onnx踩坑實(shí)錄

極市導(dǎo)讀
本文記錄了作者在深度學(xué)習(xí)模型部署是,從pytorch轉(zhuǎn)換onnx的過(guò)程中的踩坑記錄。 >>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺(jué)的最前沿
在深度學(xué)習(xí)模型部署時(shí),從pytorch轉(zhuǎn)換onnx的過(guò)程中,踩了一些坑。本文總結(jié)了這些踩坑記錄,希望可以幫助其他人。
首先,簡(jiǎn)單說(shuō)明一下pytorch轉(zhuǎn)onnx的意義。在pytorch訓(xùn)練出一個(gè)深度學(xué)習(xí)模型后,需要在TensorRT或者openvino部署,這時(shí)需要先把Pytorch模型轉(zhuǎn)換到onnx模型之后再做其它轉(zhuǎn)換。因此,在使用pytorch訓(xùn)練深度學(xué)習(xí)模型完成后,在TensorRT或者openvino或者opencv和onnxruntime部署時(shí),pytorch模型轉(zhuǎn)onnx這一步是必不可少的。接下來(lái)通過(guò)幾個(gè)實(shí)例程序,介紹pytorch轉(zhuǎn)換onnx的過(guò)程中遇到的坑。
1. opencv里的深度學(xué)習(xí)模塊不支持3維池化層
起初,我在微信公眾號(hào)里看到一篇文章《使用Python和YOLO檢測(cè)車(chē)牌》。文中展示的檢測(cè)結(jié)果如下,其實(shí)這種檢測(cè)結(jié)果并不是一個(gè)優(yōu)良的結(jié)果,可以看到檢測(cè)框里的車(chē)牌是傾斜的,如果要識(shí)別車(chē)牌里的文字,那么傾斜的車(chē)牌會(huì)嚴(yán)重影響車(chē)牌識(shí)別結(jié)果的。

對(duì)于車(chē)牌識(shí)別這種場(chǎng)景,在做車(chē)牌檢測(cè)時(shí),一種優(yōu)良的檢測(cè)結(jié)果應(yīng)該是這樣的,如下圖所示。

在輸出車(chē)牌檢測(cè)框的同時(shí)輸出檢測(cè)到的車(chē)牌的4個(gè)角點(diǎn)。有了這4個(gè)角點(diǎn)之后,對(duì)車(chē)牌做透視變換,這時(shí)的車(chē)牌就是水平放置的,最后做車(chē)牌識(shí)別,這樣就做成了一個(gè)車(chē)牌識(shí)別系統(tǒng),在這個(gè)系統(tǒng)里包含車(chē)牌檢測(cè),車(chē)牌矯正,車(chē)牌識(shí)別三個(gè)模塊。車(chē)牌檢測(cè)模塊使用retinaface,原始的retinaface是做人臉檢測(cè)的,它能輸出人臉檢測(cè)矩形框和人臉5個(gè)關(guān)鍵點(diǎn)。考慮到車(chē)牌只有4個(gè)點(diǎn),于是修改retinaface的網(wǎng)絡(luò)結(jié)構(gòu)使其輸出4個(gè)關(guān)鍵點(diǎn),然后在車(chē)牌數(shù)據(jù)集訓(xùn)練,訓(xùn)練完成后,以一幅圖片上做目標(biāo)檢測(cè)的結(jié)果如上圖所示。車(chē)牌矯正模塊使用了傳統(tǒng)圖像處理方法,關(guān)鍵函數(shù)是opencv里的getPerspectiveTransform和warpPerspective。車(chē)牌識(shí)別模塊使用Intel公司提出的LPRNet。
整套程序是基于pytorch框架運(yùn)行的,我把這套程序發(fā)布在github上,地址是 https://github.com/hpc203/license-plate-detect-recoginition-pytorch
接下來(lái)我就嘗試把pytorch模型轉(zhuǎn)換到onnx文件,然后使用opencv做車(chē)牌檢測(cè)與識(shí)別。然而在轉(zhuǎn)換完成onnx文件后,使用opencv讀取onnx文件遇到了一些坑,我在網(wǎng)上搜索,也沒(méi)有找到解決辦法。
轉(zhuǎn)換過(guò)程分兩步,首先是轉(zhuǎn)換車(chē)牌檢測(cè)retinaface到onnx文件,這一步倒是很順利,轉(zhuǎn)換沒(méi)有出錯(cuò),并且使用opencv讀取onnx文件做前向推理的輸出結(jié)果也是正確的。第二步轉(zhuǎn)換車(chē)牌識(shí)別LPRNet到onnx文件,由于Pytorch自帶torch.onnx.export轉(zhuǎn)換得到的ONNX,因此轉(zhuǎn)換的代碼很簡(jiǎn)單,在生成onnx文件后,opencv讀取onnx文件出現(xiàn)了模型其妙的錯(cuò)誤。程序運(yùn)行的結(jié)果截圖如下

從打印結(jié)果看,torch.onnx.export生成onnx文件時(shí)沒(méi)有問(wèn)題的,但是在cv2.dnn.readNet這一步出現(xiàn)異常導(dǎo)致程序中斷,并且打印出的異常信息是一連串的數(shù)字,去百度搜索也么找到解決辦法。觀察LPRNet的網(wǎng)絡(luò)結(jié)構(gòu),發(fā)現(xiàn)在LPRNet里定義了3維池化層,代碼截圖如下



可以看到依然出錯(cuò),這說(shuō)明opencv的深度學(xué)習(xí)模塊里不支持3維池化。不過(guò),對(duì)比3維池化和2維池化的前向計(jì)算原理可以發(fā)現(xiàn),3維池化其實(shí)等價(jià)于2個(gè)2維池化。程序?qū)嵗缦?/p>
程序最后最后運(yùn)行結(jié)果打印信息是相等。從這里就可以看出opencv里的深度學(xué)習(xí)模塊并不支持3維池化的前向計(jì)算,這期待后續(xù)新版本的opencv里能添加3維池化的計(jì)算。這時(shí)在LPRNet網(wǎng)絡(luò)結(jié)構(gòu)定義文件里修改3維池化層,重新生成onnx文件,opencv讀取onnx文件執(zhí)行前向計(jì)算后依然出錯(cuò),運(yùn)行結(jié)果如下。

于是繼續(xù)觀察LPRNet的網(wǎng)絡(luò)結(jié)構(gòu),在forward函數(shù)里看到有求平均值的操作,代碼截圖如下所示

注意到第一個(gè)torch.mean函數(shù)里沒(méi)有聲明在哪個(gè)維度求平均值,這說(shuō)明它是對(duì)一個(gè)4維四維張量的整體求平均值,這時(shí)候從一個(gè)4維空間搜索成一個(gè)點(diǎn),也就是一個(gè)標(biāo)量數(shù)值。但是在pytorch里,對(duì)一個(gè)張量求平均值后依然是一個(gè)張量,只不過(guò)它的維度shape是空的,示例代碼如下。這時(shí)如果想要訪問(wèn)平均值,需要加上.item(),這個(gè)是需要注意的一個(gè)pytorch知識(shí)點(diǎn)。

在修改這個(gè)代碼bug后重新生成onnx文件,使用opencv讀取onnx文件做前向計(jì)算就不再出現(xiàn)異常錯(cuò)誤了。
通過(guò)以上幾個(gè)程序?qū)嶒?yàn),可以總結(jié)出opencv讀取onnx文件做深度學(xué)習(xí)前向計(jì)算的2個(gè)坑:
(1) .opencv里的深度學(xué)習(xí)模塊不支持3維池化計(jì)算,解決辦法是修改原始網(wǎng)絡(luò)結(jié)構(gòu),把3維池化轉(zhuǎn)換成兩個(gè)2維池化,重新生成onnx文件
(2) .當(dāng)神經(jīng)網(wǎng)絡(luò)里有torch.mean和torch.sum這種把4維張量收縮到一個(gè)數(shù)值的運(yùn)算時(shí),opencv執(zhí)行forward會(huì)出錯(cuò),這時(shí)的解決辦法是修改原始網(wǎng)絡(luò)結(jié)構(gòu),在torch.mean的后面加上.item()
在解決這些坑之后,編寫(xiě)了一套使用opencv做車(chē)牌檢測(cè)與識(shí)別的程序,包含C++和python兩個(gè)版本的代碼。使用opencv的dnn模塊做前向計(jì)算,后處理模塊是自己使用C++和Python獨(dú)立編寫(xiě)的。
代碼已發(fā)布在github上,地址是:https://github.com/hpc203/license-plate-detect-recoginition-opencv
2. opencv與onnxruntime的差異
起初在github上看到一個(gè)使用DBNet檢測(cè)條形碼的程序,不過(guò)它是基于pytorch框架做的。于是我編寫(xiě)一套程序把pytorch模型轉(zhuǎn)換到onnx文件,使用opencv讀取onnx文件做前向計(jì)算。編寫(xiě)完程序后在運(yùn)行時(shí)沒(méi)有出錯(cuò),但是最后輸出的結(jié)果跟調(diào)用pytorch 的輸出結(jié)果不一致,并且從可視化結(jié)果看,沒(méi)有檢測(cè)出圖片中的條形碼。這時(shí)在看到網(wǎng)上有很多使用onnxruntime部署onnx模型的文章,于是決定使用onnxruntime部署,編寫(xiě)完程序后運(yùn)行,選取幾張快遞單圖片測(cè)試,結(jié)果如下圖所示DBNet檢測(cè)到的4個(gè)點(diǎn),圖中綠色的點(diǎn),紅色的線是把4個(gè)連接起來(lái)的直線。


并且我還編寫(xiě)了一個(gè)函數(shù)比較opencv和onnxruntime的輸出結(jié)果,程序代碼和運(yùn)行結(jié)果如下,可以看到在相同輸入,讀取同一個(gè)onnx文件的前提下,opencv和onnxruntime的輸出結(jié)果竟然不相同。

ONNXRuntime是微軟推出的一款推理框架,用戶可以非常便利的用其運(yùn)行一個(gè)onnx模型。從這個(gè)實(shí)驗(yàn),可以看出相比于opencv庫(kù),onnxruntime庫(kù)對(duì)onnx模型支持的更好。
我把這套使用DBNet檢測(cè)條形碼的程序發(fā)布在github上,地址是:https://github.com/hpc203/dbnet-barcode
3. onnxruntime支持3維池化和3維卷積
在第1節(jié)講到opencv不支持3維池化,那么onnxruntime是否支持呢?接著編寫(xiě)了一個(gè)程序探索onnxruntime對(duì)3維池化的支持情況,代碼和運(yùn)行結(jié)果如下,可以看到程序報(bào)錯(cuò)了。

查看nn.MaxPool3d的說(shuō)明文檔,截圖如下,可以看到它的輸入和輸出是5維張量,于是修改上面的代碼,把輸入調(diào)整到5維張量。

代碼和運(yùn)行結(jié)果如下,可以看到這時(shí)候onnxruntime庫(kù)能正常讀取onnx文件,并且它的輸出結(jié)果跟pytorch的輸出結(jié)果相等。

繼續(xù)實(shí)驗(yàn),把三維池化改作三維卷積,代碼和運(yùn)行結(jié)果如下,可以看到平均差異在小數(shù)點(diǎn)后11位,可以忽略不計(jì)。

在第1節(jié)講到過(guò)opencv不支持3維池化,那時(shí)候的輸入張量是4維的,如果把輸入張量改成5維的,那么opencv是否就能進(jìn)行3維池化計(jì)算呢?為此,編寫(xiě)代碼,驗(yàn)證這個(gè)想法。代碼和運(yùn)行結(jié)果如下,可以看到在cv2.dnn.blobFromImage這行代碼出錯(cuò)了。

查看cv2.dnn.blobFromImage這個(gè)函數(shù)的說(shuō)明文檔,截圖如下,可以看到它的輸入image是4維的,這說(shuō)明它不支持5維的輸入。
經(jīng)過(guò)這一系列的程序?qū)嶒?yàn)論證,可以看出onnxruntime庫(kù)對(duì)onnx模型支持的更好。如果深度學(xué)習(xí)模型有3維池化或3維卷積層,那么在轉(zhuǎn)換到onnx文件后,使用onnxruntime部署深度學(xué)習(xí)是一個(gè)不錯(cuò)的選擇。
4. onnx動(dòng)態(tài)分辨率輸入
不過(guò)我在做pytorch導(dǎo)出onnx文件時(shí),還發(fā)現(xiàn)了一個(gè)問(wèn)題。在torch.export函數(shù)里有一個(gè)輸入?yún)?shù)dynamic_axes,它表示動(dòng)態(tài)的軸,即可變的維度。假如一個(gè)神經(jīng)網(wǎng)絡(luò)輸入是動(dòng)態(tài)分辨率的,那么需要定義dynamic_axes = {'input': {2: 'height', 3: 'width'}, 'output': {2: 'height', 3: 'width'}},接下來(lái)我編寫(xiě)一個(gè)程序來(lái)驗(yàn)證,代碼和運(yùn)行結(jié)果的截圖如下

可以看到,在生成onnx文件后,使用onnxruntime庫(kù)讀取,對(duì)輸入blob的高增加10個(gè)像素單位,在run這一步出錯(cuò)了。使用opencv讀取onnx文件,代碼和運(yùn)行結(jié)果的截圖如下,可以看到依然出錯(cuò)了。

通過(guò)這個(gè)程序?qū)嶒?yàn),讓人懷疑torch.export函數(shù)的輸入?yún)?shù)dynamic_axes是否真的支持動(dòng)態(tài)分辨率輸入的。
以上這些程序?qū)嶒?yàn)是我在編寫(xiě)算法應(yīng)用程序時(shí)記錄下的一些bug和解決方案的,希望能幫助到深度學(xué)習(xí)算法開(kāi)發(fā)應(yīng)用人員少走彎路。
此外,DBNet的官方代碼里提供了轉(zhuǎn)換到onnx模型文件,于是我依然編寫(xiě)了一套使用opencv部署DBNet文字檢測(cè)的程序,依然是包含C++和Python兩個(gè)版本的代碼。官方代碼的模型是在ICDAR場(chǎng)景文本檢測(cè)數(shù)據(jù)集上訓(xùn)練的,考慮到車(chē)牌里也含有文字,我把文章開(kāi)頭展示的汽車(chē)圖片作為輸入,程序檢測(cè)結(jié)果如下,可以看到依然能檢測(cè)到車(chē)牌的4個(gè)角點(diǎn),只是不夠準(zhǔn)確。如果想要獲得準(zhǔn)確的角點(diǎn)定位,可以在車(chē)牌數(shù)據(jù)集上訓(xùn)練DBNet。
我把使用opencv部署DBNet文字檢測(cè)的程序發(fā)布在github上,程序依然是包含c++和python兩種版本的實(shí)現(xiàn),地址是:https://github.com/hpc203/dbnet-opencv-cpp-python
推薦閱讀
2021-01-17
2020-12-06
2020-11-06

# 極市原創(chuàng)作者激勵(lì)計(jì)劃 #

