<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          解決pytorch半精度amp訓練nan問題

          共 2731字,需瀏覽 6分鐘

           ·

          2021-12-18 04:42

          點擊上方視學算法”,選擇加"星標"或“置頂

          重磅干貨,第一時間送達

          作者 | 可可噠@知乎(已授權)?
          來源 | https://zhuanlan.zhihu.com/p/443166496?
          編輯 | 極市平臺

          導讀

          ?

          本文主要是收集了一些在使用pytorch自帶的amp下loss nan的情況及對應處理方案。?

          Why?

          如果要解決問題,首先就要明確原因:為什么全精度訓練時不會nan,但是半精度就開始nan?這其實分了三種情況:

          1. 計算loss 時,出現(xiàn)了除以0的情況
          2. loss過大,被半精度判斷為inf
          3. 網(wǎng)絡參數(shù)中有nan,那么運算結果也會輸出nan

          1&2我想放到后面討論,因為其實大部分報nan都是第三種情況。這里來先看看3。什么情況下會出現(xiàn)情況3?這個討論給出了不錯的解釋:

          【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://link.zhihu.com/?target=https%3A//discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17

          給大家翻譯翻譯:在使用ce loss 或者 bceloss的時候,會有l(wèi)og的操作,在半精度情況下,一些非常小的數(shù)值會被直接舍入到0,log(0)等于啥?——等于nan啊!

          于是邏輯就理通了:回傳的梯度因為log而變?yōu)閚an->網(wǎng)絡參數(shù)nan-> 每輪輸出都變成nan。(;′Д`)

          How?

          問題定義清楚,那解決方案就非常簡單了,只需要在涉及到log計算時,把輸入從half精度轉回float32:

          x = x.float()
          x_sigmoid = torch.sigmoid(x)

          一些思考&廢話

          這里我接著討論下我第一次看到nan之后,企圖直接copy別人的解決方案,但解決不掉時踩過的坑。比如:

          1. 修改優(yōu)化器的eps

          有些blog會建議你從默認的1e-8 改為 1e-3,比如這篇:【pytorch1.1 半精度訓練 Adam RMSprop 優(yōu)化器 Nan 問題】https://link.zhihu.com/?target=https%3A//blog.csdn.net/gwb281386172/article/details/104705195

          經(jīng)過上面的分析,我們就能知道為什么這種方法不行——這個方案是針對優(yōu)化器的數(shù)值穩(wěn)定性做的修改,而loss計算這一步在優(yōu)化器之前,如果loss直接nan,優(yōu)化器的eps是救不回來的(托腮)。

          那么這個方案在哪些場景下有效?——在loss輸出不是nan時(感覺說了一句廢話)。optimizer的eps是保證在進行除法backwards時,分母不出現(xiàn)0時需要加上的微小量。在半精度情況下,分母加上1e-8就仿佛聽君一席話,因此,需要把eps調大一點。

          1. 聊聊amp的GradScaler

          GradScaler是autocast的好伙伴,在官方教程上就和autocast配套使用:

          from torch.cuda.amp import autocast, GradScaler
          ...
          scaler = GradScaler()
          for epoch in epochs:
          for input, target in data:
          optimizer.zero_grad()

          with autocast():
          output = model(input)
          loss = loss_fn(output, target)
          scaler.scale(loss).backward()

          scaler.step(optimizer)
          scaler.update()

          具體原理不是我這篇文章討論的范圍,網(wǎng)上很多教程都說得很清楚了,比如這個就不錯:

          【Gemfield:PyTorch的自動混合精度(AMP)】https://zhuanlan.zhihu.com/p/165152789

          但是我這里想討論另一點:scaler.step(optimizer)的運行原理。

          在初始化GradScaler的時候,有一個參數(shù)enabled,值默認為True。如果為True,那么在調用scaler方法時會做梯度縮放來調整loss,以防半精度狀況下,梯度值過大或者過小從而被nan或者inf。而且,它還會判斷本輪loss是否是nan,如果是,那么本輪計算的梯度不會回傳,同時,當前的scale系數(shù)乘上backoff_factor,縮減scale的大小_。_

          那么,為什么這一步已經(jīng)判斷了loss是不是nan,還是會出現(xiàn)網(wǎng)絡損失持續(xù)nan的情況呢?

          這時我們就得再往前思考一步了:為什么loss會變成nan?回到文章一開始說的:

          (1)計算loss 時,出現(xiàn)了除以0的情況;

          (2)loss過大,被半精度判斷為inf;

          (3)網(wǎng)絡直接輸出了nan。

          (1)&(2),其實是可以通過scaler.step(optimizer)解決的,分別由optimizer和scaler幫我們捕捉到了nan的異常。但(3)不行,(3)意味著部分甚至全部的網(wǎng)絡參數(shù)已經(jīng)變成nan了。這可能是在更之前的梯度回傳過程中除以0導致的——首先【回傳的梯度不是nan】,所以scaler不會捕捉異常;其次,由于使用了半精度,optimizer接收到了【已經(jīng)因為精度損失而變?yōu)閚an的loss】,nan不管加上多大的eps,都還是nan,所以optimizer也無法處理異常,最終導致網(wǎng)絡參數(shù)nan。

          所以3,只能通過本文一開始提出的方案來解決。其實,大部分分類問題在使用半精度時出現(xiàn)nan的情況都是第3種情況,也只能通過把精度轉回為float32,或者在計算log時加上微小量來避免(但這樣會損失精度)。

          參考

          【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17

          如果覺得有用,就請分享到朋友圈吧!


          點個在看 paper不斷!

          瀏覽 143
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  狠狠做综合| 欧美一级爱爱视频 | 久久国产精品色综合 | 久久一二三区 | 曰韩一级A片 |