Pytorch轉(zhuǎn)ONNX-實(shí)戰(zhàn)篇2(實(shí)戰(zhàn)踩坑總結(jié))
各位讀者好,這里是BBuf,由于近期公眾號(hào)遭受惡意舉報(bào)將會(huì)在很長(zhǎng)一段時(shí)間內(nèi)失去原創(chuàng)功能并且面臨封號(hào),所以我們作者團(tuán)隊(duì)商量著重新申請(qǐng)了一個(gè)新的公眾號(hào),名字是?『PandaCV』 ,長(zhǎng)按下方二維碼關(guān)注和轉(zhuǎn)發(fā),謝謝!這個(gè)公眾號(hào)近期會(huì)通過(guò)無(wú)來(lái)源轉(zhuǎn)載的方式將GiantPandaCV公眾號(hào)的所有高質(zhì)量文章(也會(huì)舍棄一部分不好的)逐漸搬運(yùn)過(guò)去(會(huì)耗時(shí)半個(gè)月到1個(gè)月),然后在PandaCV這個(gè)公眾號(hào)上繼續(xù)發(fā)表高質(zhì)量原創(chuàng)文章。維護(hù)和諧健康的知識(shí)原創(chuàng)環(huán)境是每個(gè)人的責(zé)任,希望我們能一起努力,打造一個(gè)學(xué)習(xí)和分享的雙贏平臺(tái)。
作者丨立交橋跳水冠軍
前兩篇文章分別從理論和ONNX的核心機(jī)制描述了Pytorch轉(zhuǎn)ONNX需要注意的事情。接下來(lái)這篇文章沒(méi)有什么核心主旨,只是純粹記錄我當(dāng)時(shí)做項(xiàng)目的時(shí)候踩的坑以及應(yīng)對(duì)方案
(1)Pytorch2ONNX不支持對(duì)slice對(duì)象賦值
下面這段代碼是不被Pytorch原生的onnx轉(zhuǎn)換接口支持的,即不能對(duì)slice對(duì)象賦值
preds[:,?:,?y1:y2,?x1:x2]?+=?crop_seg_logit
仔細(xì)想想其實(shí)也比較合理,因?yàn)樯厦娴牟僮饕埠茈y在DAG上被表示,因?yàn)椴⒉粌H僅是把preds中的那個(gè)區(qū)域取出來(lái)弄個(gè)新的變量,然后在上面+1,而是直接把preds的一部分改掉了。當(dāng)時(shí)我負(fù)責(zé)MMSeg的slide inference轉(zhuǎn)換的時(shí)候遇到了這個(gè)問(wèn)題,解決方案如下:
preds?+=?F.pad(crop_seg_logit,
???????????????(int(x1),?int(preds.shape[3]?-?x2),?int(y1),
????????????????int(preds.shape[2]?-?y2)))
即我對(duì)crop_seg_logit做了一個(gè)padding,把它變成了和preds一樣的大小,這樣我就直接變成了矩陣相加,沒(méi)必要變成slice的操作了
這個(gè)方法自然很丑,而且會(huì)引出一個(gè)新的問(wèn)題,那就是Pytorch生成的onnx padding的格式,onnx runtime接收的格式以及TensorRT需要的格式都不一樣。這個(gè)就是之后的問(wèn)題了(超綱了,不講了)
這里具體的例子我懶得查了,以二維矩陣的填充為例。只記得一個(gè)轉(zhuǎn)出來(lái)的是(begin0, begin1, end0, end1),另一個(gè)是(begin0, end0, begin1, end1)
這里面begin0代表第0維左邊的填充數(shù)量,end0代表右邊的填充數(shù)量
(2)resize
當(dāng)時(shí)做segmentation模型的時(shí)候,最重要的就是resize操作。ONNX里面的resize要求output shape必須為常量(即tuple of int),因此不可以用tensor.Size作為輸入,因?yàn)槿思也⒉皇莟uple of int
if?isinstance(size,?torch.Size):
????size?=?tuple(int(x)?for?x?in?size)
所以我們必須手動(dòng)粗暴的把torch.Size變成tuple of int
當(dāng)時(shí)有reviewer吐槽我這個(gè)方法丑,要我改成tuple(size),說(shuō)Pytorch重載了tuple,直接可以把torch.Size變成tuple of int。但是很詭異的是在正常情況下的確可以,但如果一旦進(jìn)入了ONNX tracining模式,這個(gè)方法就失效了。我簡(jiǎn)單看了看,推測(cè)是因?yàn)閷?duì)tuple的重載是在C++層面做的,而ONNX tracing也會(huì)涉及到一些C++層面的事情,也就是說(shuō)ONNX tracing會(huì)重載一些C++的部分,可能正好就把tuple給抹掉了
(3) 應(yīng)對(duì)kwargs的約束
pytorch自帶的onnx轉(zhuǎn)換api: torch.onnx.export,只支持args參數(shù)。一般來(lái)說(shuō)調(diào)用這個(gè)api只需要提供model(喜聞樂(lè)見(jiàn)的nn.Module),調(diào)用model的參數(shù)args(也就是調(diào)用model.forwrd()的參數(shù))以及導(dǎo)出的文件名f。然后這個(gè)函數(shù)就會(huì)內(nèi)部執(zhí)行一遍: model(*args),執(zhí)行的時(shí)候做tracining

但是我們知道一般來(lái)說(shuō)除了args,還需要kwargs,比如model(input, getloss=False),其中input就是args,F(xiàn)alse就是kwargs。OpenMMLab里面幾乎所有的model都需要kwargs
為了繞開(kāi)這個(gè)約束,我們需要利用python的partial函數(shù),將model做個(gè)封裝:
?model.forward?=?partial(model.forward,?return_loss=False)
這樣我們可以給model提供需要的kwargs,同時(shí)又可以原封不動(dòng)的調(diào)用torch.onnx.export
注意,kwargs不能包括網(wǎng)絡(luò)的輸入,比如如果你想把input image放進(jìn)args,那么得到的onnx就會(huì)是一個(gè)沒(méi)有輸入的圖(它會(huì)把kwargs里面的input image當(dāng)成一個(gè)常量)
(4)Pytorch和ONNX Runtime結(jié)果對(duì)齊
OpenMMLab系列提供了一個(gè)很有用的功能,就是自動(dòng)比對(duì)Pytorch和ONNXRuntime的精度。這個(gè)功能可以幫助用戶確定轉(zhuǎn)出來(lái)的ONNX有沒(méi)有問(wèn)題。
然而之前也提到過(guò),ONNXRuntime和Pytorch需要的ONNX格式不一樣,而且有些計(jì)算也不一樣,因此就算結(jié)果對(duì)不上,也不能代表什么
在某些操作上,ONNXRuntime和Pytorch的行為不一致。比如對(duì)一個(gè)一維tensor:[0,0,0]調(diào)用argmax,那么ONNXRuntime返回的是0,而Pytorch是1(舉個(gè)例子,具體的差異我記不清了)
當(dāng)時(shí)我在做Detection模型的自動(dòng)比對(duì)的時(shí)候就遇到了問(wèn)題,在經(jīng)歷了nms操作之后,bbox會(huì)根據(jù)score的大小做排序,但score相同的情況下,ONNXRuntime和Pytorch的結(jié)果就會(huì)有差異。因此我們最后只選擇比對(duì)score,而不管bbox的dx,dy這些信息了
- The End -
長(zhǎng)按二維碼關(guān)注我們
本公眾號(hào)專注:
1. 技術(shù)分享;
2.?學(xué)術(shù)交流;
3.?資料共享。
歡迎關(guān)注我們,一起成長(zhǎng)!
