<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          Pytorch轉(zhuǎn)ONNX-實戰(zhàn)篇1(tracing機制)

          共 3460字,需瀏覽 7分鐘

           ·

          2021-01-17 10:57

          作者丨立交橋跳水冠軍
          來源丨h(huán)ttps://zhuanlan.zhihu.com/p/273566106
          編輯丨GiantPandaCV

          昨天的文章簡單描述了在Pytorch轉(zhuǎn)ONNX中面臨的問題和需要注意的事情,今天的文章會重點結(jié)合OpenMMlab系列中用到的Pytorch轉(zhuǎn)ONNX的小技巧來介紹實戰(zhàn)部分。

          (1)tracing的機制

          上文提到過,Pytorch轉(zhuǎn)ONNX的方式是基于tracing(追蹤),通俗來說,就是ONNX的相關代碼在一旁看著Pytorch跑一遍,運行了什么內(nèi)容就把什么記錄下來。但是在這里并不是所有Python的運行內(nèi)容都會被記錄。舉個例子,下面的代碼中,

          c?=?torch.matmul(a,?b)
          print("Blabla")
          e?=?torch.matmul(c,?d)

          其中只有第1,3行相關的內(nèi)容會被記錄,因為只有他們是和Pytorch相關的,而第二行只是普通的python語句。

          具體來說,只有ATen操作會被記錄下來。ATen可以被理解為一個Pytorch的基本操作庫,一切的Pytorch函數(shù)都是基于這些零部件構(gòu)造出來的(比如ATen就是加減乘除,所有Pytorch的其他操作,比如平方,算sigmoid,都可以根據(jù)加減乘除構(gòu)造出來)

          *之前說的ONNX無法記錄if語句的問題也是因為if并不是Aten中的操作

          雖然ONNX可以記錄所有Pytorch的執(zhí)行(即記錄所有ATen操作),但是在輸出的時候會做一個剪枝,把沒用的操作剪掉

          舉個例子,下面的程序,顯而易見第一句話是沒有用的。


          t1?=?torch.matmul(a,?b)
          t2?=?torch.matmul(c,?d)
          return?t2

          ONNX會在得到全部的操作以及他們之間的輸入輸出關系后(以DAG作為表示),根據(jù)DAG的輸出往前推,做遍歷,所有可以被遍歷到的節(jié)點被保留,其他節(jié)點直接扔掉。

          在MMDetection(https://github.com/open-mmlab/mmdetection)中,在NMS(non-Maximumnon maximum suppression)中有如下代碼:

          if?bboxes.numel()?==?0:
          ????bboxes?=?multibboxes.newzeros((0,?5))
          ????labels?=?multibboxes.newzeros((0,?),?dtype=torch.long)

          ????if?torch.onnx.isinonnxexport():
          ????????raise?RuntimeError('[ONNX?Error]?Can?not?record?NMS?'
          ???????????????????????????'as?it?has?not?been?executed?this?time')
          ????return?bboxes,?labels

          dets,?keep?=?batchednms(bboxes,?scores,?labels,?nmscfg)

          代碼邏輯很簡單,如果之前的網(wǎng)絡根本沒有輸出任何合法的bbox(第一行的分支判斷),那么顯然nms的結(jié)果就是一堆0,所以沒必要運行nms直接返回0就可以。

          如果我們想將這段代碼轉(zhuǎn)換到ONNX,之前我們提到過ONNX不能處理分支邏輯,因此只能選擇一條路去走,記錄那條路轉(zhuǎn)換得到的模型。很顯然,正常情況下我們自然期待會有較多的bbox,并且將這些bbox作為參數(shù)調(diào)用nms。

          所以如果我們發(fā)現(xiàn)模型執(zhí)行的路徑觸發(fā)了if分支,我們必須要進行一個判斷,看看是不是在轉(zhuǎn)ONNX,如果是的話我們就需要直接報錯,因為顯然轉(zhuǎn)出來的ONNX不是我們想要的。

          假設什么都不做,在這種情況下我們轉(zhuǎn)出來的模型是什么樣呢?思考一下不難發(fā)現(xiàn),假設函數(shù)的返回值就是網(wǎng)絡的最終輸出,那么我們只會得到一個2個節(jié)點的DAG,即第2,3行的兩個操作。之前說過ONNX拿到所有的DAG之后會做剪枝,在這里ONNX拿到返回值(bboxes, labels)做回溯,發(fā)現(xiàn)最頭上就是第2,3行的兩個操作,就直接停掉了。所有其他的操作,比如backbone,rpn,fpn,都會被扔掉。

          因此,在進行MMDet模型的轉(zhuǎn)換的時候,必須用真實的數(shù)據(jù)和訓練好的參數(shù)來做轉(zhuǎn)換,否則基本不會得到有效的bbox,于是就會觸發(fā)第6行的error

          (2)利用tracing機制做優(yōu)化

          在MMSeg中有一個很巧妙的利用tracing機制做優(yōu)化的例子。

          在slide inference時,我們需要計算一個count mat矩陣,這個矩陣在h, w以及對應的stride都固定的情況下會是一個常量。

          不過在訓練時,往往這些都是我們要調(diào)的參數(shù),所有MMSeg沒有選擇把這些常數(shù)保存下來,而是每次都算一遍

          ????????countmat?=?img.newzeros((batchsize,?1,?himg,?wimg))
          ????????for?hidx?in?range(hgrids):
          ????????????for?widx?in?range(wgrids):
          ????????????????y1?=?hidx?*?hstride
          ????????????????x1?=?widx?*?wstride
          ????????????????y2?=?min(y1?+?hcrop,?himg)
          ????????????????x2?=?min(x1?+?wcrop,?wimg)
          ????????????????y1?=?max(y2?-?hcrop,?0)
          ????????????????x1?=?max(x2?-?wcrop,?0)
          ????????????????cropimg?=?img[:,?:,?y1:y2,?x1:x2]
          ????????????????cropseglogit?=?self.encodedecode(cropimg,?imgmeta)
          ????????????????preds?+=?F.pad(cropseglogit,
          ???????????????????????????????(int(x1),?int(preds.shape[3]?-?x2),?int(y1),
          ????????????????????????????????int(preds.shape[2]?-?y2)))

          ????????????????countmat[:,?:,?y1:y2,?x1:x2]?+=?1
          ????????assert?(countmat?==?0).sum()?==?0
          ????????if?torch.onnx.isinonnxexport():
          ????????????#?cast?countmat?to?constant?while?exporting?to?ONNX
          ????????????countmat?=?torch.fromnumpy(
          ????????????????countmat.cpu().detach().numpy()).to(device=img.device)

          不過在部署時,這些參數(shù)往往是固定的,因此我們沒必要把它算一遍。因此在倒數(shù)第4行的if分支里,我們做了一件看似很沒用的事

          countmat?=?torch.fromnumpy(countmat.cpu().detach().numpy()).to(device=img.device)

          即我們把算出來的countmat從tensor轉(zhuǎn)換成numpy,再轉(zhuǎn)回tensor。

          其實我們的目的是切斷tracing。

          之前提到過,ONNX只能記錄ATen相關的操作,但是很顯然,tensor和numpy的互轉(zhuǎn)肯定不是ATen操作。因此在回溯的時候,當訪問到count mat,ONNX并不能發(fā)現(xiàn)它是被誰運算出來的,所以countmat就會被看作一個常數(shù)被保存下來,之前計算countmat的部分都會被扔掉


          - The End -


          GiantPandaCV

          長按二維碼關注我們

          本公眾號專注:

          1. 技術分享;

          2.?學術交流

          3.?資料共享

          歡迎關注我們,一起成長!

          瀏覽 33
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  青春草在线视频 | 日本女人毛片全网推荐免费看 | 97国产精品视频 | 国产精品亚洲专区在线播放麻豆 | 任你操逼|