解決pytorch半精度amp訓練nan問題
點擊上方“視學算法”,選擇加"星標"或“置頂”
重磅干貨,第一時間送達
導讀
?本文主要是收集了一些在使用pytorch自帶的amp下loss nan的情況及對應處理方案。?
Why?
如果要解決問題,首先就要明確原因:為什么全精度訓練時不會nan,但是半精度就開始nan?這其實分了三種情況:
計算loss 時,出現(xiàn)了除以0的情況 loss過大,被半精度判斷為inf 網(wǎng)絡參數(shù)中有nan,那么運算結果也會輸出nan
1&2我想放到后面討論,因為其實大部分報nan都是第三種情況。這里來先看看3。什么情況下會出現(xiàn)情況3?這個討論給出了不錯的解釋:
【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://link.zhihu.com/?target=https%3A//discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17
給大家翻譯翻譯:在使用ce loss 或者 bceloss的時候,會有l(wèi)og的操作,在半精度情況下,一些非常小的數(shù)值會被直接舍入到0,log(0)等于啥?——等于nan啊!
于是邏輯就理通了:回傳的梯度因為log而變?yōu)閚an->網(wǎng)絡參數(shù)nan-> 每輪輸出都變成nan。(;′Д`)
How?
問題定義清楚,那解決方案就非常簡單了,只需要在涉及到log計算時,把輸入從half精度轉回float32:
x = x.float()
x_sigmoid = torch.sigmoid(x)
一些思考&廢話
這里我接著討論下我第一次看到nan之后,企圖直接copy別人的解決方案,但解決不掉時踩過的坑。比如:
修改優(yōu)化器的eps
有些blog會建議你從默認的1e-8 改為 1e-3,比如這篇:【pytorch1.1 半精度訓練 Adam RMSprop 優(yōu)化器 Nan 問題】https://link.zhihu.com/?target=https%3A//blog.csdn.net/gwb281386172/article/details/104705195
經(jīng)過上面的分析,我們就能知道為什么這種方法不行——這個方案是針對優(yōu)化器的數(shù)值穩(wěn)定性做的修改,而loss計算這一步在優(yōu)化器之前,如果loss直接nan,優(yōu)化器的eps是救不回來的(托腮)。
那么這個方案在哪些場景下有效?——在loss輸出不是nan時(感覺說了一句廢話)。optimizer的eps是保證在進行除法backwards時,分母不出現(xiàn)0時需要加上的微小量。在半精度情況下,分母加上1e-8就仿佛聽君一席話,因此,需要把eps調大一點。
聊聊amp的GradScaler
GradScaler是autocast的好伙伴,在官方教程上就和autocast配套使用:
from torch.cuda.amp import autocast, GradScaler
...
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
with autocast():
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
具體原理不是我這篇文章討論的范圍,網(wǎng)上很多教程都說得很清楚了,比如這個就不錯:
【Gemfield:PyTorch的自動混合精度(AMP)】https://zhuanlan.zhihu.com/p/165152789
但是我這里想討論另一點:scaler.step(optimizer)的運行原理。
在初始化GradScaler的時候,有一個參數(shù)enabled,值默認為True。如果為True,那么在調用scaler方法時會做梯度縮放來調整loss,以防半精度狀況下,梯度值過大或者過小從而被nan或者inf。而且,它還會判斷本輪loss是否是nan,如果是,那么本輪計算的梯度不會回傳,同時,當前的scale系數(shù)乘上backoff_factor,縮減scale的大小_。_
那么,為什么這一步已經(jīng)判斷了loss是不是nan,還是會出現(xiàn)網(wǎng)絡損失持續(xù)nan的情況呢?
這時我們就得再往前思考一步了:為什么loss會變成nan?回到文章一開始說的:
(1)計算loss 時,出現(xiàn)了除以0的情況;
(2)loss過大,被半精度判斷為inf;
(3)網(wǎng)絡直接輸出了nan。
(1)&(2),其實是可以通過scaler.step(optimizer)解決的,分別由optimizer和scaler幫我們捕捉到了nan的異常。但(3)不行,(3)意味著部分甚至全部的網(wǎng)絡參數(shù)已經(jīng)變成nan了。這可能是在更之前的梯度回傳過程中除以0導致的——首先【回傳的梯度不是nan】,所以scaler不會捕捉異常;其次,由于使用了半精度,optimizer接收到了【已經(jīng)因為精度損失而變?yōu)閚an的loss】,nan不管加上多大的eps,都還是nan,所以optimizer也無法處理異常,最終導致網(wǎng)絡參數(shù)nan。
所以3,只能通過本文一開始提出的方案來解決。其實,大部分分類問題在使用半精度時出現(xiàn)nan的情況都是第3種情況,也只能通過把精度轉回為float32,或者在計算log時加上微小量來避免(但這樣會損失精度)。
參考
【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17
如果覺得有用,就請分享到朋友圈吧!

點個在看 paper不斷!
