人臉識(shí)別損失函數(shù)的匯總 | Pytorch版本實(shí)現(xiàn)
點(diǎn)擊上方“小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)
這篇文章的重點(diǎn)不在于講解FR的各種Loss,因?yàn)橹跎弦呀?jīng)有很多,搜一下就好,本文主要提供了各種Loss的Pytorch實(shí)現(xiàn)以及Mnist的可視化實(shí)驗(yàn),一方面讓大家借助代碼更深刻地理解Loss的設(shè)計(jì),另一方面直觀的比較各種Loss的有效性,是否漲點(diǎn)并不是我關(guān)注的重點(diǎn),因?yàn)檫@些Loss的設(shè)計(jì)理念之一就是增大收斂難度,所以在Mnist這樣的簡(jiǎn)單任務(wù)上訓(xùn)練同樣的epoch,先進(jìn)的Loss并不一定能帶來點(diǎn)數(shù)的提升,但從視覺效果可以明顯的看出特征的分離程度,而且從另一方面來說,分類正確不代表一定能能在用歐式/余弦距離做1:1驗(yàn)證的時(shí)候也正確...
本文主要仿照CenterLoss文中的實(shí)驗(yàn)結(jié)構(gòu),使用了一個(gè)相對(duì)復(fù)雜一些的LeNet升級(jí)版網(wǎng)絡(luò),把輸入圖片Embedding成2維特征向量以便于可視化。
對(duì)了,代碼里用到了TensorBoardX來可視化,當(dāng)然如果你沒裝,可以注釋掉相關(guān)代碼,我也寫了本地保存圖片,雖然很不喜歡TensorFlow,但TensorBoard還是真香,比Visdom強(qiáng)太多了...
早就想寫這篇文章了,趁著五一假期終于...
具體代碼在Github:github.com/MccreeZhao/F 有興趣的話點(diǎn)個(gè)Star呀~雖然剛起步還沒什么東西
文章里只展示loss寫法
Softmax
公式推導(dǎo)

Pytorch代碼實(shí)現(xiàn)
class Linear(nn.Module):def __init__(self):super(Linear, self).__init__()self.weight = nn.Parameter(torch.Tensor(2, 10)) # (input,output)nn.init.xavier_uniform_(self.weight)def forward(self, x, label):out = x.mm(self.weight)loss = F.cross_entropy(out, label)return out, loss
emmm...現(xiàn)實(shí)生活中根本沒人會(huì)這么寫好吧!明明就有現(xiàn)成的Linear層啊喂!
寫成這樣只是為了方便統(tǒng)一框架...
可視化

這一張圖是二維化的特征,注意觀察不同兩類任意點(diǎn)之間的余弦距離和歐氏距離

這張圖是將特征歸一化的結(jié)果,能更好的反映余弦距離,豎線是該類在最后一個(gè)FC層的權(quán)重,等同于類別中心(這一點(diǎn)對(duì)于理解loss的發(fā)展還是挺關(guān)鍵的)
后面的圖片也都是這種形式,大家可以比較著來看
Modified Softmax
公式推導(dǎo)

去除了權(quán)重的模長(zhǎng)和偏置對(duì)loss的影響,將特征映射到了超球面,同時(shí)避免了樣本量差異帶來的預(yù)測(cè)傾向性(樣本量大可能導(dǎo)致權(quán)重模長(zhǎng)偏大)
Pytorch代碼實(shí)現(xiàn)
class Modified(nn.Module):def __init__(self):super(Modified, self).__init__()self.weight = nn.Parameter(torch.Tensor(2,10))#(input,output)nn.init.xavier_uniform_(self.weight)self.weight.data.uniform_(-1,1).renorm_(2,1,1e-5).mul_(1e5)#因?yàn)閞enorm采用的是maxnorm,所以先縮小再放大以防止norm結(jié)果小于1def forward(self, x):w=self.weightww=w.renorm(2,1,1e-5).mul(1e5)out = x.mm(ww)return out
可視化


這里要提一句,如果大家留心的話可以發(fā)現(xiàn),雖然modified loss并沒有太好的聚攏效果,但確讓類別中心準(zhǔn)確地落在了feature的中心,這對(duì)于網(wǎng)絡(luò)的性能是有很大好處的,但是具體原因我沒想出來...希望能有大佬在評(píng)論區(qū)給解釋一下...
NormFace
既然權(quán)重的模長(zhǎng)有影響,F(xiàn)eature的模長(zhǎng)必然也有影響,具體還是看文章,另外,質(zhì)量差的圖片feature模長(zhǎng)往往較短,做normalize之后消除了這個(gè)影響,有利有弊,還沒有達(dá)成一致觀點(diǎn),目前主流的Loss還是包括feature normalize的
公式推導(dǎo)

可視化


就是一個(gè)字:猛!感覺有了NormFace,后面的花式Loss都體現(xiàn)不出來效果了...
Pytorch代碼實(shí)現(xiàn)
class NormFace(nn.Module):def __init__(self):super(NormFace, self).__init__()self.weight = nn.Parameter(torch.Tensor(2, 10)) # (input,output)nn.init.xavier_uniform_(self.weight)self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)self.s = 16# 因?yàn)閞enorm采用的是maxnorm,所以先縮小再放大以防止norm結(jié)果小于1def forward(self, x, label):cosine = F.normalize(x).mm(F.normalize(self.weight, dim=0))loss = F.cross_entropy(self.s * cosine, label)return cosine, loss
SphereFace:A-softmax
為了進(jìn)一步約束特征向量之間的余弦距離,我們?nèi)藶榈卦黾邮諗侩y度,給兩個(gè)向量之間的夾角乘上一個(gè)因子:m
公式推導(dǎo)

Pytorch代碼實(shí)現(xiàn)
class SphereFace(nn.Module):def __init__(self, m=4):super(SphereFace, self).__init__()self.weight = nn.Parameter(torch.Tensor(2, 10)) # (input,output)nn.init.xavier_uniform_(self.weight)self.weight.data.renorm_(2, 1, 1e-5).mul_(1e5)self.m = mself.mlambda = [ # calculate cos(mx)lambda x: x ** 0,lambda x: x ** 1,lambda x: 2 * x ** 2 - 1,lambda x: 4 * x ** 3 - 3 * x,lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x]self.it = 0self.LambdaMin = 3self.LambdaMax = 30000.0self.gamma = 0def forward(self, input, label):# 注意,在原始的A-softmax中是不對(duì)x進(jìn)行標(biāo)準(zhǔn)化的,# 標(biāo)準(zhǔn)化可以提升性能,也會(huì)增加收斂難度,A-softmax本來就很難收斂cos_theta = F.normalize(input).mm(F.normalize(self.weight, dim=0))cos_theta = cos_theta.clamp(-1, 1) # 防止出現(xiàn)異常# 以上計(jì)算出了傳統(tǒng)意義上的cos_theta,但為了cos(m*theta)的單調(diào)遞減,需要使用phi_thetacos_m_theta = self.mlambda[self.m](cos_theta)# 計(jì)算theta,依據(jù)theta的區(qū)間把k的取值定下來theta = cos_theta.data.acos()k = (self.m * theta / 3.1415926).floor()phi_theta = ((-1) ** k) * cos_m_theta - 2 * kx_norm = input.pow(2).sum(1).pow(0.5) # 這個(gè)地方?jīng)Q定x帶不帶模長(zhǎng),不帶就要乘sx_cos_theta = cos_theta * x_norm.view(-1, 1)x_phi_theta = phi_theta * x_norm.view(-1, 1)############ 以上計(jì)算target logit,下面構(gòu)造loss,退火訓(xùn)練#####self.it += 1 # 用來調(diào)整lambdatarget = label.view(-1, 1) # (B,1)onehot = torch.zeros(target.shape[0], 10).cuda().scatter_(1, target, 1)lamb = max(self.LambdaMin, self.LambdaMax / (1 + 0.2 * self.it))output = x_cos_theta * 1.0 # 如果不乘可能會(huì)有數(shù)值錯(cuò)誤?output[onehot.byte()] -= x_cos_theta[onehot.byte()] * (1.0 + 0) / (1 + lamb)output[onehot.byte()] += x_phi_theta[onehot.byte()] * (1.0 + 0) / (1 + lamb)# 到這一步可以等同于原來的Wx+b=y的輸出了,# 到這里使用了Focal Loss,如果直接使用cross_Entropy的話似乎效果會(huì)減弱許多log = F.log_softmax(output, 1)log = log.gather(1, target)log = log.view(-1)pt = log.data.exp()loss = -1 * (1 - pt) ** self.gamma * logloss = loss.mean()# loss = F.cross_entropy(x_cos_theta,target.view(-1))#換成crossEntropy效果會(huì)差return output, loss
可視化


InsightFace(ArcSoftmax)
公式推導(dǎo)

Pytorch代碼實(shí)現(xiàn)
class ArcMarginProduct(nn.Module):def __init__(self, s=32, m=0.5):super(ArcMarginProduct, self).__init__()self.in_feature = 2self.out_feature = 10self.s = sself.m = mself.weight = nn.Parameter(torch.Tensor(2, 10)) # (input,output)nn.init.xavier_uniform_(self.weight)self.weight.data.renorm_(2, 1, 1e-5).mul_(1e5)self.cos_m = math.cos(m)self.sin_m = math.sin(m)# 為了保證cos(theta+m)在0-pi單調(diào)遞減:self.th = math.cos(3.1415926 - m)self.mm = math.sin(3.1415926 - m) * mdef forward(self, x, label):cosine = F.normalize(x).mm(F.normalize(self.weight, dim=0))cosine = cosine.clamp(-1, 1) # 數(shù)值穩(wěn)定sine = torch.sqrt(torch.max(1.0 - torch.pow(cosine, 2), torch.ones(cosine.shape).cuda() * 1e-7)) # 數(shù)值穩(wěn)定##print(self.sin_m)phi = cosine * self.cos_m - sine * self.sin_m # 兩角和公式# # 為了保證cos(theta+m)在0-pi單調(diào)遞減:# phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)#必要性未知#one_hot = torch.zeros_like(cosine)one_hot.scatter_(1, label.view(-1, 1), 1)output = (one_hot * phi) + ((1.0 - one_hot) * cosine)output = output * self.sloss = F.cross_entropy(output, label)return output, loss
可視化


ArcSoftmax需要更久的訓(xùn)練,這個(gè)收斂還不夠充分...顏值堪憂,另外ArcSoftmax經(jīng)常出現(xiàn)類別在特征空間分布不均勻的情況,這個(gè)也有點(diǎn)費(fèi)解,難道在訓(xùn)FR模型的時(shí)候先用softmax然后慢慢加margin有奇效?SphereFace那種退火的訓(xùn)練方式效果好會(huì)不會(huì)和這個(gè)有關(guān)呢...
Center Loss
亂入一個(gè)歐式距離的細(xì)作
公式推導(dǎo)

其中
是每個(gè)類別對(duì)應(yīng)的一個(gè)中心,在這里就是一個(gè)二維坐標(biāo)啦
Pytorch代碼實(shí)現(xiàn)
class centerloss(nn.Module):def __init__(self):super(centerloss, self).__init__()self.center = nn.Parameter(10 * torch.randn(10, 2))self.lamda = 0.2self.weight = nn.Parameter(torch.Tensor(2, 10)) # (input,output)nn.init.xavier_uniform_(self.weight)def forward(self, feature, label):batch_size = label.size()[0]nCenter = self.center.index_select(dim=0, index=label)distance = feature.dist(nCenter)centerloss = (1 / 2.0 / batch_size) * distanceout = feature.mm(self.weight)ceLoss = F.cross_entropy(out, label)return out, ceLoss + self.lamda * centerloss
這里實(shí)現(xiàn)的是center的部分,還要跟原始的CEloss相加的,具體看github吧
可視化


會(huì)不會(huì)配合weight norm效果更佳呢?以后再說吧...
總結(jié)
先寫到這里,如果大家有興趣可以去github點(diǎn)個(gè)star之類的...作為一個(gè)研一快結(jié)束的弱雞剛剛學(xué)會(huì)使用github...也是沒誰了...
參考文獻(xiàn):
Wang M, Deng W. Deep face recognition: A survey[J]. arXiv preprint arXiv:1804.06655, 2018.
好消息!
小白學(xué)視覺知識(shí)星球
開始面向外開放啦??????
下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程 在「小白學(xué)視覺」公眾號(hào)后臺(tái)回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺、目標(biāo)跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。 下載2:Python視覺實(shí)戰(zhàn)項(xiàng)目52講 在「小白學(xué)視覺」公眾號(hào)后臺(tái)回復(fù):Python視覺實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測(cè)、車道線檢測(cè)、車輛計(jì)數(shù)、添加眼線、車牌識(shí)別、字符識(shí)別、情緒檢測(cè)、文本內(nèi)容提取、面部識(shí)別等31個(gè)視覺實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺。 下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講 在「小白學(xué)視覺」公眾號(hào)后臺(tái)回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。 交流群
歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測(cè)、分割、識(shí)別、醫(yī)學(xué)影像、GAN、算法競(jìng)賽等微信群(以后會(huì)逐漸細(xì)分),請(qǐng)掃描下面微信號(hào)加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請(qǐng)按照格式備注,否則不予通過。添加成功后會(huì)根據(jù)研究方向邀請(qǐng)進(jìn)入相關(guān)微信群。請(qǐng)勿在群內(nèi)發(fā)送廣告,否則會(huì)請(qǐng)出群,謝謝理解~

