<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          實(shí)踐教程 | 解決pytorch半精度amp訓(xùn)練nan問(wèn)題

          共 3190字,需瀏覽 7分鐘

           ·

          2021-12-18 21:30

          ↑ 點(diǎn)擊藍(lán)字?關(guān)注極市平臺(tái)

          作者 | 可可噠@知乎(已授權(quán))?
          來(lái)源 | https://zhuanlan.zhihu.com/p/443166496?
          編輯 | 極市平臺(tái)

          極市導(dǎo)讀

          ?

          本文主要是收集了一些在使用pytorch自帶的amp下loss nan的情況及對(duì)應(yīng)處理方案。?>>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺(jué)的最前沿

          Why?

          如果要解決問(wèn)題,首先就要明確原因:為什么全精度訓(xùn)練時(shí)不會(huì)nan,但是半精度就開(kāi)始nan?這其實(shí)分了三種情況:

          1. 計(jì)算loss 時(shí),出現(xiàn)了除以0的情況
          2. loss過(guò)大,被半精度判斷為inf
          3. 網(wǎng)絡(luò)參數(shù)中有nan,那么運(yùn)算結(jié)果也會(huì)輸出nan

          1&2我想放到后面討論,因?yàn)槠鋵?shí)大部分報(bào)nan都是第三種情況。這里來(lái)先看看3。什么情況下會(huì)出現(xiàn)情況3?這個(gè)討論給出了不錯(cuò)的解釋?zhuān)?/p>

          【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://link.zhihu.com/?target=https%3A//discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17

          給大家翻譯翻譯:在使用ce loss 或者 bceloss的時(shí)候,會(huì)有l(wèi)og的操作,在半精度情況下,一些非常小的數(shù)值會(huì)被直接舍入到0,log(0)等于啥?——等于nan啊!

          于是邏輯就理通了:回傳的梯度因?yàn)閘og而變?yōu)閚an->網(wǎng)絡(luò)參數(shù)nan-> 每輪輸出都變成nan。(;′Д`)

          How?

          問(wèn)題定義清楚,那解決方案就非常簡(jiǎn)單了,只需要在涉及到log計(jì)算時(shí),把輸入從half精度轉(zhuǎn)回float32:

          x = x.float()
          x_sigmoid = torch.sigmoid(x)

          一些思考&廢話(huà)

          這里我接著討論下我第一次看到nan之后,企圖直接copy別人的解決方案,但解決不掉時(shí)踩過(guò)的坑。比如:

          1. 修改優(yōu)化器的eps

          有些blog會(huì)建議你從默認(rèn)的1e-8 改為 1e-3,比如這篇:【pytorch1.1 半精度訓(xùn)練 Adam RMSprop 優(yōu)化器 Nan 問(wèn)題】https://link.zhihu.com/?target=https%3A//blog.csdn.net/gwb281386172/article/details/104705195

          經(jīng)過(guò)上面的分析,我們就能知道為什么這種方法不行——這個(gè)方案是針對(duì)優(yōu)化器的數(shù)值穩(wěn)定性做的修改,而loss計(jì)算這一步在優(yōu)化器之前,如果loss直接nan,優(yōu)化器的eps是救不回來(lái)的(托腮)。

          那么這個(gè)方案在哪些場(chǎng)景下有效?——在loss輸出不是nan時(shí)(感覺(jué)說(shuō)了一句廢話(huà))。optimizer的eps是保證在進(jìn)行除法backwards時(shí),分母不出現(xiàn)0時(shí)需要加上的微小量。在半精度情況下,分母加上1e-8就仿佛聽(tīng)君一席話(huà),因此,需要把eps調(diào)大一點(diǎn)。

          1. 聊聊amp的GradScaler

          GradScaler是autocast的好伙伴,在官方教程上就和autocast配套使用:

          from torch.cuda.amp import autocast, GradScaler
          ...
          scaler = GradScaler()
          for epoch in epochs:
          for input, target in data:
          optimizer.zero_grad()

          with autocast():
          output = model(input)
          loss = loss_fn(output, target)
          scaler.scale(loss).backward()

          scaler.step(optimizer)
          scaler.update()

          具體原理不是我這篇文章討論的范圍,網(wǎng)上很多教程都說(shuō)得很清楚了,比如這個(gè)就不錯(cuò):

          【Gemfield:PyTorch的自動(dòng)混合精度(AMP)】https://zhuanlan.zhihu.com/p/165152789

          但是我這里想討論另一點(diǎn):scaler.step(optimizer)的運(yùn)行原理。

          在初始化GradScaler的時(shí)候,有一個(gè)參數(shù)enabled,值默認(rèn)為T(mén)rue。如果為T(mén)rue,那么在調(diào)用scaler方法時(shí)會(huì)做梯度縮放來(lái)調(diào)整loss,以防半精度狀況下,梯度值過(guò)大或者過(guò)小從而被nan或者inf。而且,它還會(huì)判斷本輪loss是否是nan,如果是,那么本輪計(jì)算的梯度不會(huì)回傳,同時(shí),當(dāng)前的scale系數(shù)乘上backoff_factor,縮減scale的大小_。_

          那么,為什么這一步已經(jīng)判斷了loss是不是nan,還是會(huì)出現(xiàn)網(wǎng)絡(luò)損失持續(xù)nan的情況呢?

          這時(shí)我們就得再往前思考一步了:為什么loss會(huì)變成nan?回到文章一開(kāi)始說(shuō)的:

          (1)計(jì)算loss 時(shí),出現(xiàn)了除以0的情況;

          (2)loss過(guò)大,被半精度判斷為inf;

          (3)網(wǎng)絡(luò)直接輸出了nan。

          (1)&(2),其實(shí)是可以通過(guò)scaler.step(optimizer)解決的,分別由optimizer和scaler幫我們捕捉到了nan的異常。但(3)不行,(3)意味著部分甚至全部的網(wǎng)絡(luò)參數(shù)已經(jīng)變成nan了。這可能是在更之前的梯度回傳過(guò)程中除以0導(dǎo)致的——首先【回傳的梯度不是nan】,所以scaler不會(huì)捕捉異常;其次,由于使用了半精度,optimizer接收到了【已經(jīng)因?yàn)榫葥p失而變?yōu)閚an的loss】,nan不管加上多大的eps,都還是nan,所以optimizer也無(wú)法處理異常,最終導(dǎo)致網(wǎng)絡(luò)參數(shù)nan。

          所以3,只能通過(guò)本文一開(kāi)始提出的方案來(lái)解決。其實(shí),大部分分類(lèi)問(wèn)題在使用半精度時(shí)出現(xiàn)nan的情況都是第3種情況,也只能通過(guò)把精度轉(zhuǎn)回為float32,或者在計(jì)算log時(shí)加上微小量來(lái)避免(但這樣會(huì)損失精度)。

          參考

          【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17

          如果覺(jué)得有用,就請(qǐng)分享到朋友圈吧!

          △點(diǎn)擊卡片關(guān)注極市平臺(tái),獲取最新CV干貨

          公眾號(hào)后臺(tái)回復(fù)“transformer”獲取最新Transformer綜述論文下載~


          極市干貨
          課程/比賽:珠港澳人工智能算法大賽保姆級(jí)零基礎(chǔ)人工智能教程
          算法trick目標(biāo)檢測(cè)比賽中的tricks集錦從39個(gè)kaggle競(jìng)賽中總結(jié)出來(lái)的圖像分割的Tips和Tricks
          技術(shù)綜述:一文弄懂各種loss function工業(yè)圖像異常檢測(cè)最新研究總結(jié)(2019-2020)


          #?CV技術(shù)社群邀請(qǐng)函?#

          △長(zhǎng)按添加極市小助手
          添加極市小助手微信(ID : cvmart4)

          備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)


          即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群


          每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與?10000+來(lái)自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺(jué)開(kāi)發(fā)者互動(dòng)交流~



          覺(jué)得有用麻煩給個(gè)在看啦~??
          瀏覽 208
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  w黄视频欧美精品韩日 | 69精品无码成人久久久久久 | 欧美三级片在线播放 | 親子亂子倫XXXX | 日本色视频一区二区 |