實(shí)踐教程 | 解決pytorch半精度amp訓(xùn)練nan問(wèn)題

極市導(dǎo)讀
?本文主要是收集了一些在使用pytorch自帶的amp下loss nan的情況及對(duì)應(yīng)處理方案。?>>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺(jué)的最前沿
Why?
如果要解決問(wèn)題,首先就要明確原因:為什么全精度訓(xùn)練時(shí)不會(huì)nan,但是半精度就開(kāi)始nan?這其實(shí)分了三種情況:
計(jì)算loss 時(shí),出現(xiàn)了除以0的情況 loss過(guò)大,被半精度判斷為inf 網(wǎng)絡(luò)參數(shù)中有nan,那么運(yùn)算結(jié)果也會(huì)輸出nan
1&2我想放到后面討論,因?yàn)槠鋵?shí)大部分報(bào)nan都是第三種情況。這里來(lái)先看看3。什么情況下會(huì)出現(xiàn)情況3?這個(gè)討論給出了不錯(cuò)的解釋?zhuān)?/p>
【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://link.zhihu.com/?target=https%3A//discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17
給大家翻譯翻譯:在使用ce loss 或者 bceloss的時(shí)候,會(huì)有l(wèi)og的操作,在半精度情況下,一些非常小的數(shù)值會(huì)被直接舍入到0,log(0)等于啥?——等于nan啊!
于是邏輯就理通了:回傳的梯度因?yàn)閘og而變?yōu)閚an->網(wǎng)絡(luò)參數(shù)nan-> 每輪輸出都變成nan。(;′Д`)
How?
問(wèn)題定義清楚,那解決方案就非常簡(jiǎn)單了,只需要在涉及到log計(jì)算時(shí),把輸入從half精度轉(zhuǎn)回float32:
x = x.float()
x_sigmoid = torch.sigmoid(x)
一些思考&廢話(huà)
這里我接著討論下我第一次看到nan之后,企圖直接copy別人的解決方案,但解決不掉時(shí)踩過(guò)的坑。比如:
修改優(yōu)化器的eps
有些blog會(huì)建議你從默認(rèn)的1e-8 改為 1e-3,比如這篇:【pytorch1.1 半精度訓(xùn)練 Adam RMSprop 優(yōu)化器 Nan 問(wèn)題】https://link.zhihu.com/?target=https%3A//blog.csdn.net/gwb281386172/article/details/104705195
經(jīng)過(guò)上面的分析,我們就能知道為什么這種方法不行——這個(gè)方案是針對(duì)優(yōu)化器的數(shù)值穩(wěn)定性做的修改,而loss計(jì)算這一步在優(yōu)化器之前,如果loss直接nan,優(yōu)化器的eps是救不回來(lái)的(托腮)。
那么這個(gè)方案在哪些場(chǎng)景下有效?——在loss輸出不是nan時(shí)(感覺(jué)說(shuō)了一句廢話(huà))。optimizer的eps是保證在進(jìn)行除法backwards時(shí),分母不出現(xiàn)0時(shí)需要加上的微小量。在半精度情況下,分母加上1e-8就仿佛聽(tīng)君一席話(huà),因此,需要把eps調(diào)大一點(diǎn)。
聊聊amp的GradScaler
GradScaler是autocast的好伙伴,在官方教程上就和autocast配套使用:
from torch.cuda.amp import autocast, GradScaler
...
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
with autocast():
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
具體原理不是我這篇文章討論的范圍,網(wǎng)上很多教程都說(shuō)得很清楚了,比如這個(gè)就不錯(cuò):
【Gemfield:PyTorch的自動(dòng)混合精度(AMP)】https://zhuanlan.zhihu.com/p/165152789
但是我這里想討論另一點(diǎn):scaler.step(optimizer)的運(yùn)行原理。
在初始化GradScaler的時(shí)候,有一個(gè)參數(shù)enabled,值默認(rèn)為T(mén)rue。如果為T(mén)rue,那么在調(diào)用scaler方法時(shí)會(huì)做梯度縮放來(lái)調(diào)整loss,以防半精度狀況下,梯度值過(guò)大或者過(guò)小從而被nan或者inf。而且,它還會(huì)判斷本輪loss是否是nan,如果是,那么本輪計(jì)算的梯度不會(huì)回傳,同時(shí),當(dāng)前的scale系數(shù)乘上backoff_factor,縮減scale的大小_。_
那么,為什么這一步已經(jīng)判斷了loss是不是nan,還是會(huì)出現(xiàn)網(wǎng)絡(luò)損失持續(xù)nan的情況呢?
這時(shí)我們就得再往前思考一步了:為什么loss會(huì)變成nan?回到文章一開(kāi)始說(shuō)的:
(1)計(jì)算loss 時(shí),出現(xiàn)了除以0的情況;
(2)loss過(guò)大,被半精度判斷為inf;
(3)網(wǎng)絡(luò)直接輸出了nan。
(1)&(2),其實(shí)是可以通過(guò)scaler.step(optimizer)解決的,分別由optimizer和scaler幫我們捕捉到了nan的異常。但(3)不行,(3)意味著部分甚至全部的網(wǎng)絡(luò)參數(shù)已經(jīng)變成nan了。這可能是在更之前的梯度回傳過(guò)程中除以0導(dǎo)致的——首先【回傳的梯度不是nan】,所以scaler不會(huì)捕捉異常;其次,由于使用了半精度,optimizer接收到了【已經(jīng)因?yàn)榫葥p失而變?yōu)閚an的loss】,nan不管加上多大的eps,都還是nan,所以optimizer也無(wú)法處理異常,最終導(dǎo)致網(wǎng)絡(luò)參數(shù)nan。
所以3,只能通過(guò)本文一開(kāi)始提出的方案來(lái)解決。其實(shí),大部分分類(lèi)問(wèn)題在使用半精度時(shí)出現(xiàn)nan的情況都是第3種情況,也只能通過(guò)把精度轉(zhuǎn)回為float32,或者在計(jì)算log時(shí)加上微小量來(lái)避免(但這樣會(huì)損失精度)。
參考
【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17
如果覺(jué)得有用,就請(qǐng)分享到朋友圈吧!
公眾號(hào)后臺(tái)回復(fù)“transformer”獲取最新Transformer綜述論文下載~

#?CV技術(shù)社群邀請(qǐng)函?#

備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)
即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群
每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與?10000+來(lái)自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺(jué)開(kāi)發(fā)者互動(dòng)交流~

