<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          pytorch煉丹,那些不為人知的細節(jié)

          共 7169字,需瀏覽 15分鐘

           ·

          2022-01-13 21:13

          ??

          作者丨Fatescript

          來源丨h(huán)ttps://zhuanlan.zhihu.com/p/450779978;僅學(xué)術(shù)分享

          本文算是我工作一年多以來的一些想法和經(jīng)驗,最早發(fā)布在曠視研究院內(nèi)部的論壇中,本著開放和分享的精神發(fā)布在我的知乎專欄中,如果想看干貨的話可以直接跳過動機部分。另外,后續(xù)在這個專欄中,我會做一些關(guān)于原理和設(shè)計方面的一些分享,希望能給領(lǐng)域從業(yè)人員提供一些看待問題的不一樣的視角。

          動機

          前段時間走在路上,一直在思考一個問題:我的時間開銷很多都被拿去給別人解釋一些在我看起來顯而易見的問題了,比如( https://link.zhihu.com/?target=https%3A//github.com/Megvii- BaseDetection/cvpods )里面的一些code寫法問題(雖然這在某些方面說明了文檔建設(shè)的不完善),而這變相導(dǎo)致了我實際工作時間的減少,如何讓別人少問一些我覺得答案顯而易見的問題?如何讓別人提前規(guī)避一些不必要的坑?只有解決掉這樣的一些問題,我才能從一件件繁瑣的小事中解放出來,把精力放在我真正關(guān)心的事情上去。

          其實之前同事有跟我說過類似的話,每次帶一個新人,都要告訴他:你的實現(xiàn)需要注意這里blabla,還要注意那里blabla。說實話,我很佩服那些帶intern時候非常細致和知無不言的人,但我本性上并不喜歡每次花費時間去解釋一些我覺得顯而易見的問題,所以我寫下了這個帖子,把我踩過的坑和留下來的經(jīng)驗分享出去。希望能夠方便別人,同時也節(jié)約我的時間。

          加入曠視以來,個人一直在做一些關(guān)于框架相關(guān)的內(nèi)容,所以內(nèi)容主要偏向于模型訓(xùn)練之類的工作。因為 一個擁有知識的人是無法想象知識在別人腦海中的樣子的(the curse of knowledge),所以我只能選取被問的最多的,和我認為最應(yīng)該知道的

          準備好了的話,我們就啟航出發(fā)(另,這篇專欄文章會長期進行更新)。

          坑/經(jīng)驗

          Data模塊

          1. python圖像處理用的最多的兩個庫是opencv和Pillow(PIL),但是兩者讀取出來的圖像并不一樣, opencv讀取的圖像格式的三個通道是BGR形式的,但是PIL是RGB格式的 。這個問題看起來很小,但是衍生出來的坑可以有很多,最常見的場景就是數(shù)據(jù)增強和預(yù)訓(xùn)練模型中。比如有些數(shù)據(jù)增強的方法是基于channel維度的,比如megengine里面的HueTransform,這一行代碼 (https://github.com/MegEngine/MegEngine/blob/4d72e7071d6b8f8240edc56c6853384850b7407f/imperative/python/megengine/data/transform/vision/transform.py#L958 ) 顯然是需要確保圖像是BGR的,但是經(jīng)常會有人只看有Transform就無腦用了,從來沒有考慮過這些問題。
          2. 接上條,RGB和BGR的另一個問題就是導(dǎo)致預(yù)訓(xùn)練模型載入后訓(xùn)練的方式不對,最常見的場景就是預(yù)訓(xùn)練模型的input channel是RGB的(例如torch官方來的預(yù)訓(xùn)練模型),然后你用cv2做數(shù)據(jù)處理,最后還忘了convert成RGB的格式,那么就是會有問題。這個問題應(yīng)該很多煉丹的同學(xué)沒有注意過,我之前寫CenterNet-better(https://github.com/FateScript/CenterNet-better)就發(fā)現(xiàn)CenterNet(https://github.com/xingyizhou/CenterNet)存在這么一個問題,要知道當時這可是一個有著3k多star的倉庫,但是從來沒有人意識到有這個問題。當然,依照我的經(jīng)驗,如果你訓(xùn)練的iter足夠多,即使你的channel有問題,對于結(jié)果的影響也會非常小。不過,既然能做對,為啥不注意這些問題一次性做對呢?
          3. torchvision中提供的模型,都是輸入圖像經(jīng)過了ToTensor操作train出來的。也就是說最后在進入網(wǎng)絡(luò)之前會統(tǒng)一除以255從而將網(wǎng)絡(luò)的輸入變到0到1之間。torchvision的文檔(https://pytorch.org/vision/stable/models.html)給出了他們使用的mean和std,也是0-1的mean和std。如果你使用torch預(yù)訓(xùn)練的模型,但是輸入還是0-255的,那么恭喜你,在載入模型上你又會踩一個大坑(要么你的圖像先除以255,要么你的code中mean和std的數(shù)值都要乘以255)。
          4. ToTensor之后接數(shù)據(jù)處理的坑。上一條說了ToTensor之后圖像變成了0到1的,但是一些數(shù)據(jù)增強對數(shù)值做處理的時候,是針對標準圖像,很多人ToTensor之后接了這樣一個數(shù)據(jù)增強,最后就是練出來的丹是廢的(心疼電費QaQ)。
          5. 數(shù)據(jù)集里面有一個圖特別詭異,只要train到那一張圖就會炸顯存(CUDA OOM),別的圖訓(xùn)練起來都沒有問題,應(yīng)該怎么處理?通常出現(xiàn)這個問題,首先判斷數(shù)據(jù)本身是不是有問題。如果數(shù)據(jù)本身有問題,在一開始生成Dataset對象的時候去掉就行了。如果數(shù)據(jù)本身沒有問題,只不過因為一些特殊原因?qū)е嘛@存炸了(比如檢測中圖像的GT boxes過多的問題),可以catch一個CUDA OOM的error之后將一些邏輯放在CPU上,最后retry一下,這樣只是會慢一個iter,但是訓(xùn)練過程還是可以完整走完的,在我們開源的YOLOX里有類似的參考code(https://github.com/Megvii-BaseDetection/YOLOX/blob/0.1.0/yolox/models/yolo_head.py#L330-L334)。
          6. pytorch中dataloader的坑。有時候會遇到pytorch num_workers=0(也就是單進程)沒有問題,但是多進程就會報一些看不懂的錯的現(xiàn)象,這種情況通常是因為torch到了ulimit的上限,更核心的原因是 torch的dataloader不會釋放文件描述符 (參考issue: https://github.com/pytorch/pytorch/issues/973)。可以ulimit -n 看一下機器的設(shè)置。跑程序之前修改一下對應(yīng)的數(shù)值。
          7. opencv和dataloader的神奇聯(lián)動。很多人經(jīng)常來問為啥要寫cv2.setNumThreads(0),其實是因為cv2在做resize等op的時候會用多線程,當torch的dataloader是多進程的時候,多進程套多線程,很容易就卡死了(具體哪里死鎖了我沒探究很深)。除了setNumThreads之外,通常還要加一句cv2.ocl.setUseOpenCL(False),原因是cv2使用opencl和cuda一起用的時候通常會拖慢速度,加了萬事大吉,說不定還能加速。感謝評論區(qū) @Yuxin Wu(https://www.zhihu.com/people/ppwwyyxx)?大大的指正
          8. dataloader會在epoch結(jié)束之后進行類似重新加載的操作,復(fù)現(xiàn)這個問題的code稍微有些長,放在后面了。這個問題算是可以說是一個高級bug/feature了,可能導(dǎo)致的問題之一就是煉丹師在本地的code上進行了一些修改,然后訓(xùn)練過程直接加載進去了。解決方法也很簡單,讓你的sampler源源不斷地產(chǎn)生數(shù)據(jù)就好,這樣即使本地code有修改也不會加載進去。

          Module模塊

          1. BatchNorm在訓(xùn)練和推斷的時候的行為是不一致的。這也是新人最常見的錯誤(類似的算子還有dropout,這里提一嘴, pytorch的dropout在eval的時候行為是Identity ,之前有遇到過實習(xí)生說dropout加了沒效果,直到我看了他的code:x = F.dropout(x, p=0.5)
          2. BatchNorm疊加分布式訓(xùn)練的坑。在使用DDP(DistributedDataParallel)進行訓(xùn)練的時候,每張卡上的BN統(tǒng)計量是可能不一樣的,仔細檢查broadcast_buffer這個參數(shù) 。DDP的默認行為是在forward之前將rank0 的 buffer做一次broadcast(broadcast_buffer=True),但是一些常用的開源檢測倉庫是將broadcast_buffer設(shè)置成False的(參考:mmdet(https://github.com/facebookresearch/detectron2/blob/f50ec07cf220982e2c4861c5a9a17c4864ab5bfd/tools/plain_train_net.py#L206)?和 detectron2(https://github.com/facebookresearch/detectron2/blob/f50ec07cf220982e2c4861c5a9a17c4864ab5bfd/tools/plain_train_net.py#L206),我猜是在檢測任務(wù)中因為batchsize過小,統(tǒng)一用卡0的統(tǒng)計量會掉點) 這個問題在一邊訓(xùn)練一邊測試的code中更常見 ,比如說你train了5個epoch,然后要分布式測試一下。一般的邏輯是將數(shù)據(jù)集分到每塊卡上,每塊卡進行inference,最后gather到卡0上進行測點。但是 因為每張卡統(tǒng)計量是不一樣的,所以和那種把卡0的模型broadcast到不同卡上測試出來的結(jié)果是不一樣的。這也是為啥通常訓(xùn)練完測的點和單獨起了一個測試腳本跑出來的點不一樣的原因 (當然你用SyncBN就不會有這個問題)。
          3. Pytorch的SyncBN在1.5之前一直實現(xiàn)的有bug,所以有一些老倉庫是存在使用SyncBN之后掉點的問題的。
          4. 用了多卡開多尺度訓(xùn)練,明明尺度更小了,但是速度好像不是很理想?這個問題涉及到多卡的原理,因為分布式訓(xùn)練的時候,在得到新的參數(shù)之后往往需要進行一次同步。假設(shè)有兩張卡,卡0的尺度非常小,卡1的尺度非常大,那么就會出現(xiàn)卡0始終在等卡1,于是就出現(xiàn)了雖然有的尺度變小了,但是整體的訓(xùn)練速度并沒有變快的現(xiàn)象(木桶效應(yīng))。解決這個問題的思路就是 盡量把負載拉均衡一些
          5. 多卡的小batch模擬大batch(梯度累積)的坑。假設(shè)我們在單卡下只能塞下batchsize = 2,那么為了模擬一個batchsize = 8的效果,通常的做法是forward / backward 4次,不清理梯度,step一次(當然考慮BN的統(tǒng)計量問題這種做法和單純的batchsize=8肯定還是有一些差別的)。在多卡下,因為調(diào)用loss.backward的時候會做grad的同步,所以說前三次調(diào)用backward的時候需要加ddp.no_sync(https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=no_sync#torch.nn.parallel.DistributedDataParallel.no_sync)的context manager(不加的話,第一次bp之后,各個卡上的grad此時會進行同步),最后一次則不需要加。當然,我看很多倉庫并沒有這么做,我只能理解他們就是單純想做梯度累積(BTW,加了ddp.no_sync會使得程序快一些,畢竟加了之后bp過程是無通訊的)。
          6. 浮點數(shù)的加法其實不遵守交換律的 ,這個通常能衍生出來GPU上的運算結(jié)果不能嚴格復(fù)現(xiàn)的現(xiàn)象。可能一些非計算機軟件專業(yè)的同學(xué)并不理解這一件事情,直接自己開一個python終端體驗可能會更好:


          print(1e100?+?1e-4?+?-1e100)??#?ouptut:?0
          print(1e100?+?-1e100?+?1e-4)??#?output:?0.0001

          訓(xùn)練模塊

          1. FP16訓(xùn)練/混合精度訓(xùn)練。使用Apex訓(xùn)練混合精度模型,在保存checkpoint用于繼續(xù)訓(xùn)練的時候,除了model和optimizer本身的state_dict之外,還需要保存一下amp的state_dict,這個在amp的文檔(https://link.zhihu.com/?target=https%3A//nvidia.github.io/apex/amp.html%23checkpointing)中也有提過。(當然,經(jīng)驗上來說忘了保存影響不大,會多花幾個iter search一個loss scalar出來)
          2. 多機分布式訓(xùn)練卡死的問題。好友 @NoahSYZhang(https://www.zhihu.com/people/syzhangbuaa) 遇到的一個坑。場景是先申請了兩個8卡機,然后機器1和機器2用前4塊卡做通訊(local rank最大都是4,總共是兩機8卡)。可以初始化process group,但是在使用DDP的時候會卡死。原因在于pytorch在做DDP的時候會猜測一個rank,參考code(https://github.com/pytorch/pytorch/blob/0d437fe6d0ef17648072eb586484a4a5a080b094/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1622-L1630)。對于上面的場景,第二個機器上因為存在卡5到卡8,而對應(yīng)的rank也是5到8,所以DDP就會認為自己需要同步的是卡5到卡8,于是就卡死了。

          復(fù)現(xiàn)Code

          Data部分


          from?torch.utils.data?import?DataLoader
          from?torch.utils.data?import?Dataset
          import?tqdm
          import?time


          class?SimpleDataset(Dataset):
          ????def?__init__(self,?length=400):
          ????????self.length?=?length
          ????????self.data_list?=?list(range(length))

          ????def?__getitem__(self,?index):
          ????????data?=?self.data_list[index]
          ????????time.sleep(0.1)
          ????????return?data

          ????def?__len__(self):
          ????????return?self.length


          def?train(local_rank):
          ????dataset?=?SimpleDataset()
          ????dataloader?=?DataLoader(dataset,?batch_size=1,?num_workers=2)
          ????iter_loader?=?iter(dataloader)
          ????max_iter?=?100000
          ????for?_?in?tqdm.tqdm(range(max_iter)):
          ????????try:
          ????????????_?=?next(iter_loader)
          ????????except?StopIteration:
          ????????????print("Refresh?here?!!!!!!!!")
          ????????????iter_loader?=?iter(dataloader)
          ????????????_?=?next(iter_loader)


          if?__name__?==?"__main__":
          ????import?torch.multiprocessing?as?mp
          ????mp.spawn(train,?args=(),?nprocs=2,?daemon=False)

          當程序運行起來的時候,可以在Dataset里面的__getitem__方法里面加一個print輸出一些內(nèi)容,在refresh之后,就會print對應(yīng)的內(nèi)容哦(看到現(xiàn)象是不是覺得自己以前煉的丹可能有問題了呢hhh)

          一些碎碎念

          一口氣寫了這么多條也有點累了,后續(xù)有踩到新坑的話我也會繼續(xù)更新這篇文章的。畢竟寫這篇文章是希望工作中不再會有人踩類似的坑 & 煉丹的人能夠?qū)ι疃葘W(xué)習(xí)框架有意識(雖然某種程度上來講這算是個心智負擔(dān))。

          如果說今年來什么事情是最大的收獲的話,那就是理解了一個開放的生態(tài)是可以迸發(fā)出極強的活力的,也希望能看到更多的人來分享自己遇到的問題和解決的思路。畢竟探索的答案只是一個副產(chǎn)品,過程本身才是最大的財寶。



          猜您喜歡:

          超110篇!CVPR 2021最全GAN論文匯總梳理!

          超100篇!CVPR 2020最全GAN論文梳理匯總!

          拆解組新的GAN:解耦表征MixNMatch

          StarGAN第2版:多域多樣性圖像生成


          附下載 |?《可解釋的機器學(xué)習(xí)》中文版

          附下載 |《TensorFlow 2.0 深度學(xué)習(xí)算法實戰(zhàn)》

          附下載 |《計算機視覺中的數(shù)學(xué)方法》分享


          《基于深度學(xué)習(xí)的表面缺陷檢測方法綜述》

          《零樣本圖像分類綜述: 十年進展》

          《基于深度神經(jīng)網(wǎng)絡(luò)的少樣本學(xué)習(xí)綜述》



          瀏覽 91
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  国内一级性爱网站 | 成人免费精品视频 | 伊人大香蕉视频在线观看 | 亚洲AV无码精品成人影院麻豆 | 日本一级做a |