【CV】10分鐘理解Focal loss數(shù)學(xué)原理與Pytorch代碼
原文鏈接:https://amaarora.github.io/2020/06/29/FocalLoss.html
原文作者:Aman Arora
Focal loss 是一個(gè)在目標(biāo)檢測(cè)領(lǐng)域常用的損失函數(shù)。最近看到一篇博客,趁這個(gè)機(jī)會(huì),學(xué)習(xí)和翻譯一下,與大家一起交流和分享。
在這篇博客中,我們將會(huì)理解什么是Focal loss,并且什么時(shí)候應(yīng)該使用它。同時(shí)我們會(huì)深入理解下其背后的數(shù)學(xué)原理與pytorch 實(shí)現(xiàn).
什么是Focal loss,它是用來干嘛的? 為什么Focal loss有效,其中的原理是什么? Alpha and Gamma? 怎么在代碼中實(shí)現(xiàn)它? Credits
什么是Focal loss,它是用來干嘛的?
在了解什么是Focal Loss以及有關(guān)它的所有詳細(xì)信息之前,我們首先快速直觀地了解Focal Loss的實(shí)際作用。Focal loss最早是 He et al 在論文 Focal Loss for Dense Object Detection 中實(shí)現(xiàn)的。
在這篇文章發(fā)表之前,對(duì)象檢測(cè)實(shí)際上一直被認(rèn)為是一個(gè)很難解決的問題,尤其是很難檢測(cè)圖像中的小尺寸對(duì)象。請(qǐng)參見下面的示例,與其他圖片相比,摩托車的尺寸相對(duì)較小, 所以該模型無法很好地預(yù)測(cè)摩托車的存在。

因此,F(xiàn)ocal loss在樣本不平衡的情況下特別有用。特別是在“對(duì)象檢測(cè)”的情況下,大多數(shù)像素通常都是背景,圖像中只有很少數(shù)的像素具有我們感興趣的對(duì)象。
這是經(jīng)過Focal loss訓(xùn)練后同一模型對(duì)同樣圖片的預(yù)測(cè)。

那么為什么Focal loss有效,其中的原理是什么?
既然我們已經(jīng)看到了“Focal loss”可以做什么的一個(gè)例子,接下來讓我們嘗試去理解為什么它可以起作用。下面是了解Focal loss的最重要的一張圖:

在上圖中,“藍(lán)”線代表交叉熵?fù)p失。X軸即“預(yù)測(cè)為真實(shí)標(biāo)簽的概率”(為簡單起見,將其稱為pt)。舉例來說,假設(shè)模型預(yù)測(cè)某物是自行車的概率為0.6,而它確實(shí)是自行車, 在這種情況下的pt為0.6。而如果同樣的情況下對(duì)象不是自行車。則pt為0.4,因?yàn)榇颂幍恼鎸?shí)標(biāo)簽是0,而對(duì)象不是自行車的概率為0.4(1-0.6)。
Y軸是給定pt后Focal loss和CE的loss的值。
從圖像中可以看出,當(dāng)模型預(yù)測(cè)為真實(shí)標(biāo)簽的概率為0.6左右時(shí),交叉熵?fù)p失仍在0.5左右。因此,為了在訓(xùn)練過程中減少損失,我們的模型將必須以更高的概率來預(yù)測(cè)到真實(shí)標(biāo)簽。換句話說,交叉熵?fù)p失要求模型對(duì)自己的預(yù)測(cè)非常有信心。但這也同樣會(huì)給模型表現(xiàn)帶來負(fù)面影響。
深度學(xué)習(xí)模型會(huì)變得過度自信, 因此模型的泛化能力會(huì)下降.
這個(gè)模型過度自信的問題同樣在另一篇出色的論文 Beyond temperature scaling: Obtaining well-calibrated multiclass probabilities with Dirichlet calibration 被強(qiáng)調(diào)過。
另外,作為重新思考計(jì)算機(jī)視覺的初始架構(gòu)的一部分而引入的標(biāo)簽平滑是解決該問題的另一種方法。
Focal loss與上述解決方案不同。從比較Focal loss與CrossEntropy的圖表可以看出,當(dāng)使用γ> 1的Focal Loss可以減少“分類得好的樣本”或者說“模型預(yù)測(cè)正確概率大”的樣本的訓(xùn)練損失,而對(duì)于“難以分類的示例”,比如預(yù)測(cè)概率小于0.5的,則不會(huì)減小太多損失。因此,在數(shù)據(jù)類別不平衡的情況下,會(huì)讓模型的注意力放在稀少的類別上,因?yàn)檫@些類別的樣本見過的少,比較難分。
Focal loss的數(shù)學(xué)定義如下:

Alpha and Gamma?
那么在Focal loss 中的alpha和gamma是什么呢?我們會(huì)將alpha記為α,gamma記為γ。
我們可以這樣來理解fig3
γ?控制曲線的形狀.?γ的值越大, 好分類樣本的loss就越小, 我們就可以把模型的注意力投向那些難分類的樣本. 一個(gè)大的?γ?讓獲得小loss的樣本范圍擴(kuò)大了.
同時(shí),當(dāng)γ=0時(shí),這個(gè)表達(dá)式就退化成了Cross Entropy Loss,眾所周知地

定義“ pt”如下,按照其真實(shí)意義:

將上述兩個(gè)式子合并,Cross Entropy Loss其實(shí)就變成了下式。

現(xiàn)在我們知道了γ的作用,那么α是干什么的呢?
除了Focal loss以外,另一種處理類別不均衡的方法是引入權(quán)重。給稀有類別以高權(quán)重,給統(tǒng)治地位的類或普通類以小權(quán)重。這些權(quán)重我們也可以用α表示。

alpha-CE
加上了這些權(quán)重確實(shí)幫助處理了類別的 不均衡,focal loss的論文報(bào)道:
類間不均衡較大會(huì)導(dǎo)致,交叉熵?fù)p失在訓(xùn)練的時(shí)候收到影響。易分類的樣本的分類錯(cuò)誤的損失占了整體損失的絕大部分,并主導(dǎo)梯度。盡管α平衡了正面/負(fù)面例子的重要性,但它并未區(qū)分簡單/困難例子。
作者想要解釋的是:
盡管我們加上了α, 它也確實(shí)對(duì)不同的類別加上了不同的權(quán)重, 從而平衡了正負(fù)樣本的重要性 ,但在大多數(shù)例子中,只做這個(gè)是不夠的. 我們同樣要做的是減少容易分類的樣本分類錯(cuò)誤的損失。因?yàn)椴蝗坏脑?,這些容易分類的樣本就主導(dǎo)了我們的訓(xùn)練.
那么Focal loss 怎么處理的呢,它相對(duì)交叉熵加上了一個(gè)乘性的因子(1 ? pt)**γ,從而像我們上面所講的,降低了易分類樣本區(qū)間內(nèi)產(chǎn)生的loss。
再看下Focal loss的表達(dá),是不是清晰了許多。

怎么在代碼中實(shí)現(xiàn)呢?
這是Focal loss在Pytorch中的實(shí)現(xiàn)。
class WeightedFocalLoss(nn.Module):"Non weighted version of Focal Loss"def __init__(self, alpha=.25, gamma=2):super(WeightedFocalLoss, self).__init__()self.alpha = torch.tensor([alpha, 1-alpha]).cuda()self.gamma = gammadef forward(self, inputs, targets):BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')targets = targets.type(torch.long)at = self.alpha.gather(0, targets.data.view(-1))pt = torch.exp(-BCE_loss)F_loss = at*(1-pt)**self.gamma * BCE_lossreturn F_loss.mean()
如果你理解了alpha和gamma的意思,那么這個(gè)實(shí)現(xiàn)應(yīng)該都能理解。同時(shí),像文章中提到的一樣,這里是對(duì)BCE進(jìn)行因子的相乘。
Credits
貼上作者的 twitter ,當(dāng)然如果大家有什么問題討論,也可以在公眾號(hào)留言。
fig-1?and?fig-2?are from the?Fastai 2018 course?Lecture-09!
未完待續(xù)
今天給大家分享到這里,感謝大家的閱讀和支持,我們會(huì)繼續(xù)給大家分享我們的所思所想所學(xué),希望大家都有收獲!
往期精彩回顧
獲取一折本站知識(shí)星球優(yōu)惠券,復(fù)制鏈接直接打開:
https://t.zsxq.com/yFQV7am
本站qq群1003271085。
加入微信群請(qǐng)掃碼進(jìn)群:
