語義分割中的 loss function 最全面匯總

極市導(dǎo)讀
?本文總結(jié)了語義分割中的5個損失函數(shù),詳細(xì)介紹每個損失函數(shù)的使用場景以及特點。?>>加入極市CV技術(shù)交流群,走在計算機視覺的最前沿
目錄:
cross entropy loss weighted loss focal loss dice soft loss soft iou loss 總結(jié)
1、cross entropy loss
用于圖像語義分割任務(wù)的最常用損失函數(shù)是像素級別的交叉熵?fù)p失,這種損失會逐個檢查每個像素,將對每個像素類別的預(yù)測結(jié)果(概率分布向量)與我們的獨熱編碼標(biāo)簽向量進行比較。
假設(shè)我們需要對每個像素的預(yù)測類別有5個,則預(yù)測的概率分布向量長度為5:

每個像素對應(yīng)的損失函數(shù)為:
整個圖像的損失就是對每個像素的損失求平均值。
特別注意的是,binary entropy loss 是針對類別只有兩個的情況,簡稱 bce loss,損失函數(shù)公式為:
2、weighted loss
由于交叉熵?fù)p失會分別評估每個像素的類別預(yù)測,然后對所有像素的損失進行平均,因此我們實質(zhì)上是在對圖像中的每個像素進行平等地學(xué)習(xí)。如果多個類在圖像中的分布不均衡,那么這可能導(dǎo)致訓(xùn)練過程由像素數(shù)量多的類所主導(dǎo),即模型會主要學(xué)習(xí)數(shù)量多的類別樣本的特征,并且學(xué)習(xí)出來的模型會更偏向?qū)⑾袼仡A(yù)測為該類別。
FCN論文和U-Net論文中針對這個問題,對輸出概率分布向量中的每個值進行加權(quán),即希望模型更加關(guān)注數(shù)量較少的樣本,以緩解圖像中存在的類別不均衡問題。
比如對于二分類,正負(fù)樣本比例為1: 99,此時模型將所有樣本都預(yù)測為負(fù)樣本,那么準(zhǔn)確率仍有99%這么高,但其實該模型沒有任何使用價值。
為了平衡這個差距,就對正樣本和負(fù)樣本的損失賦予不同的權(quán)重,帶權(quán)重的二分類損失函數(shù)公式如下:
要減少假陰性樣本的數(shù)量,可以增大 pos_weight;要減少假陽性樣本的數(shù)量,可以減小 pos_weight。
3、focal loss
上面針對不同類別的像素數(shù)量不均衡提出了改進方法,但有時還需要將像素分為難學(xué)習(xí)和容易學(xué)習(xí)這兩種樣本。
容易學(xué)習(xí)的樣本模型可以很輕松地將其預(yù)測正確,模型只要將大量容易學(xué)習(xí)的樣本分類正確,loss就可以減小很多,從而導(dǎo)致模型不怎么顧及難學(xué)習(xí)的樣本,所以我們要想辦法讓模型更加關(guān)注難學(xué)習(xí)的樣本。
對于較難學(xué)習(xí)的樣本,將 bce loss 修改為:
其中的 通常設(shè)置為2。
舉個例子,預(yù)測一個正樣本,如果預(yù)測結(jié)果為0.95,這是一個容易學(xué)習(xí)的樣本,有 ,損失直接減少為原來的1/400。
而如果預(yù)測結(jié)果為0.4,這是一個難學(xué)習(xí)的樣本,有 ,損失減小為原來的1/4,雖然也在減小,但是相對來說,減小的程度小得多。
所以通過這種修改,就可以使模型更加專注于學(xué)習(xí)難學(xué)習(xí)的樣本。
而將這個修改和對正負(fù)樣本不均衡的修改合并在一起,就是大名鼎鼎的 focal loss:
4、dice soft loss
語義分割任務(wù)中常用的還有一個基于 Dice 系數(shù)的損失函數(shù),該系數(shù)實質(zhì)上是兩個樣本之間重疊的度量。此度量范圍為 0~1,其中 Dice 系數(shù)為1表示完全重疊。Dice 系數(shù)最初是用于二進制數(shù)據(jù)的,可以計算為:
代表集合A和B之間的公共元素,并且 代表集合A中的元素數(shù)量(對于集合B同理)。
對于在預(yù)測的分割掩碼上評估 Dice 系數(shù),我們可以將 近似為預(yù)測掩碼和標(biāo)簽掩碼之間的逐元素乘法,然后對結(jié)果矩陣求和。

計算 Dice 系數(shù)的分子中有一個2,那是因為分母中對兩個集合的元素個數(shù)求和,兩個集合的共同元素被加了兩次。為了設(shè)計一個可以最小化的損失函數(shù),可以簡單地使用 。這種損失函數(shù)被稱為 soft Dice loss,這是因為我們直接使用預(yù)測出的概率,而不是使用閾值將其轉(zhuǎn)換成一個二進制掩碼。
Dice loss是針對前景比例太小的問題提出的,dice系數(shù)源于二分類,本質(zhì)上是衡量兩個樣本的重疊部分。
對于神經(jīng)網(wǎng)絡(luò)的輸出,分子與我們的預(yù)測和標(biāo)簽之間的共同激活有關(guān),而分母分別與每個掩碼中的激活數(shù)量有關(guān),這具有根據(jù)標(biāo)簽掩碼的尺寸對損失進行歸一化的效果。

對于每個類別的mask,都計算一個 Dice 損失:
將每個類的 Dice 損失求和取平均,得到最后的 Dice soft loss。
下面是代碼實現(xiàn):
def?soft_dice_loss(y_true,?y_pred,?epsilon=1e-6):?
????'''?
????Soft?dice?loss?calculation?for?arbitrary?batch?size,?number?of?classes,?and?number?of?spatial?dimensions.
????Assumes?the?`channels_last`?format.
??
????#?Arguments
????????y_true:?b?x?X?x?Y(?x?Z...)?x?c?One?hot?encoding?of?ground?truth
????????y_pred:?b?x?X?x?Y(?x?Z...)?x?c?Network?output,?must?sum?to?1?over?c?channel?(such?as?after?softmax)?
????????epsilon:?Used?for?numerical?stability?to?avoid?divide?by?zero?errors
????
????#?References
????????V-Net:?Fully?Convolutional?Neural?Networks?for?Volumetric?Medical?Image?Segmentation?
????????https://arxiv.org/abs/1606.04797
????????More?details?on?Dice?loss?formulation?
????????https://mediatum.ub.tum.de/doc/1395260/1395260.pdf?(page?72)
????????
????????Adapted?from?https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
????'''
????
????#?skip?the?batch?and?class?axis?for?calculating?Dice?score
????axes?=?tuple(range(1,?len(y_pred.shape)-1))?
????numerator?=?2.?*?np.sum(y_pred?*?y_true,?axes)
????denominator?=?np.sum(np.square(y_pred)?+?np.square(y_true),?axes)
????
????return?1?-?np.mean(numerator?/?(denominator?+?epsilon))?#?average?over?classes?and?batch
5、soft IoU loss
前面我們知道計算 Dice 系數(shù)的公式,其實也可以表示為:
其中 TP 為真陽性樣本,F(xiàn)P 為假陽性樣本,F(xiàn)N 為假陰性樣本。分子和分母中的 TP 樣本都加了兩次。
IoU 的計算公式和這個很像,區(qū)別就是 TP 只計算一次:
和 Dice soft loss 一樣,通過 IoU 計算損失也是使用預(yù)測的概率值:
其中 C 表示總的類別數(shù)。
總結(jié):
交叉熵?fù)p失把每個像素都當(dāng)作一個獨立樣本進行預(yù)測,而 dice loss 和 iou loss 則以一種更“整體”的方式來看待最終的預(yù)測輸出。
這兩類損失是針對不同情況,各有優(yōu)點和缺點,在實際應(yīng)用中,可以同時使用這兩類損失來進行互補。
參考:
An overview of semantic image segmentation.(https://www.jeremyjordan.me/semantic-segmentation/)
Loss Functions for Medical Image Segmentation(https://medium.com/@junma11/loss-functions-for-medical-image-segmentation-a-taxonomy-cefa5292eec0)
Losses for Image Segmentation(https://lars76.github.io/neural-networks/object-detection/losses-for-segmentation/)
公眾號后臺回復(fù)“CVPR 2022”獲取論文合集打包下載~

