淺談LabelSmooth兩種實現(xiàn)及推導
【GiantPandaCV導語】
因為最近跑VIT的實驗,所以有用到timm的一些配置,在mixup的實現(xiàn)里面發(fā)現(xiàn)labelsmooth的實現(xiàn)是按照最基本的方法來的,與很多pytorch的實現(xiàn)略有不同,所以簡單做了一個推導。
一、交叉熵損失(CrossEntropyLoss)
先簡單講一下交叉熵損失,也是我們做分類任務(wù)里面最常用的一種損失,公式如下:
這里的 表示的是模型輸出的logits后經(jīng)過softmax的結(jié)果,shape為 , 表示的是對應(yīng)的label,經(jīng)常用onehot來表示,pytorch版本可以使用scalar表示,shape為 ,這里 表示為batchsize, 表示為向量長度。
可以簡單拆解為如下:
log_softmax
這個很簡單,就是做softmax后取對數(shù),公式如下:
NLLloss
這個玩意的全程叫做negative log-likelihood(負對數(shù)似然損失), 簡單解釋下: 假設(shè)需要求解一個分布 ? ,由于未知其表達式,所以先定義一個分布 ,通過 來使得 靠近 的分布。這里采用最大似然估計來進行求解, ,不斷的更新參數(shù) 使得 來自 的樣本 ? 在 中的概率越來越高。但是有個問題,連乘對于求導不友好,計算也過于復雜,所以可以對其取對數(shù),有
最大化對數(shù)似然函數(shù)就等效于最小化負對數(shù)似然函數(shù),所以加個負號,公式如下:
由于求loss的時候,采用的是onehot形式,除去當前類別為1其余都為0,所以有:
這個形式就和交叉熵形式一致,所以NLLLoss也叫CrossEntropyLoss。
二、LabelSmooth
由于Softmax會存在一個問題,就是Over Confidence,會使得模型對于弱項的照顧很少。LabelSmooth的作用就是為了降低Softmax所帶來的的高Confidence的影響,讓模型略微關(guān)注到低概率分布的權(quán)重。這樣做也會有點影響,最終預(yù)測的時候,模型輸出的置信度會稍微低一些,需要比較細致的閾值過濾。

假設(shè) ,表示對標簽進行平滑的數(shù)值,那么就有
這里 classes表示類別數(shù)量, target表示當前的類別,帶有l(wèi)abelsmooth的CELoss就變成了:
相比原始的CELoss,LabelSmoothCELoss則是每一項都會參與到loss計算。
三、公式推導
#?labelsmooth?
import?torch?
import?torch.nn?as?nn?
import?torch.nn.functional?as?F?
class?LabelSmoothingCrossEntropy(nn.Module):
????"""
????NLL?loss?with?label?smoothing.
????"""
????def?__init__(self,?smoothing=0.1):
????????"""
????????Constructor?for?the?LabelSmoothing?module.
????????:param?smoothing:?label?smoothing?factor
????????"""
????????super(LabelSmoothingCrossEntropy,?self).__init__()
????????assert?smoothing?1.0
????????self.smoothing?=?smoothing
????????self.confidence?=?1.?-?smoothing
????def?forward(self,?x,?target):
????????logprobs?=?F.log_softmax(x,?dim=-1)
????????nll_loss?=?-logprobs.gather(dim=-1,?index=target.unsqueeze(1))
????????nll_loss?=?nll_loss.squeeze(1)
????????smooth_loss?=?-logprobs.mean(dim=-1)
????????loss?=?self.confidence?*?nll_loss?+?self.smoothing?*?smooth_loss
????????return?loss.mean()
可以看到這個code的實現(xiàn)和公式有點出入,第一部分是self.confidence * nll_loss, 第二部分是self.smoothing * smooth_loss。我們將其展開為:
假設(shè) k為target,那么對于onehot來說除了 以外均為0,所以有:
進一步有組合 項:
最后可以寫成矩陣點乘的形式:
我們表示 為LabelSmooth后的標簽 ,和第二節(jié)中的設(shè)定對齊,所以得到的Loss就是原本的表達式:
與之對應(yīng)的timm中的mixup部分的LabelSmoothCELoss代碼如下:
def?one_hot(x,?num_classes,?on_value=1.,?off_value=0.,?device='cuda'):
????x?=?x.long().view(-1,?1)
????return?torch.full((x.size()[0],?num_classes),?off_value,?device=device).scatter_(1,?x,?on_value)
def?mixup_target(target,?num_classes,?lam=1.,?smoothing=0.0,?device='cuda'):
????off_value?=?smoothing?/?num_classes
????on_value?=?1.?-?smoothing?+?off_value
????y1?=?one_hot(target,?num_classes,?on_value=on_value,?off_value=off_value,?device=device)
????y2?=?one_hot(target.flip(0),?num_classes,?on_value=on_value,?off_value=off_value,?device=device)
????return?y1?*?lam?+?y2?*?(1.?-?lam)
四,總結(jié)
LabelSmooth可以用來標簽平滑,從公示推導方面來講,也可以充當正則的作用,尤其是針對難分類別的情況下,效果會表現(xiàn)更好一些。

