深度學(xué)習(xí)的多個(gè) loss 是如何平衡的?
點(diǎn)擊上方“小白學(xué)視覺(jué)”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)
來(lái)自 | 知乎
地址 | https://www.zhihu.com/question/375794498
編輯 | AI有道
在一個(gè)端到端訓(xùn)練的網(wǎng)絡(luò)中,如果最終的loss = a*loss1+b*loss2+c*loss3...,對(duì)于a,b,c這些超參的選擇,有沒(méi)有什么方法?
作者:Evan
https://www.zhihu.com/question/375794498/answer/1052779937
其實(shí)這是目前深度學(xué)習(xí)領(lǐng)域被某種程度上忽視了的一個(gè)重要問(wèn)題,在近幾年大火的multi-task learning,generative adversarial networks, 等等很多機(jī)器學(xué)習(xí)任務(wù)和方法里面都會(huì)遇到,很多paper的做法都是暴力調(diào)參結(jié)果玄學(xué)……這里偷偷跟大家分享兩個(gè)很有趣的研究視角
1. 從預(yù)測(cè)不確定性的角度引入Bayesian框架,根據(jù)各個(gè)loss分量當(dāng)前的大小自動(dòng)設(shè)定其權(quán)重。有代表性的工作參見Alex Kendall等人的CVPR2018文章:
Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics
https://arxiv.org/abs/1705.07115
文章的二作Yarin Gal是Zoubin Ghahramani的高徒,近幾年結(jié)合Bayesian思想和深度學(xué)習(xí)做了很多solid的工作。
2. 構(gòu)建所有l(wèi)oss的Pareto,以一次訓(xùn)練的超低代價(jià)得到多種超參組合對(duì)應(yīng)的結(jié)果。有代表性的工作參見Intel在2018年NeurIPS(對(duì),就是那個(gè)剛改了名字的機(jī)器學(xué)習(xí)頂會(huì))發(fā)表的:
Multi-Task Learning as Multi-Objective Optimization
http://papers.nips.cc/paper/7334-multi-task-learning-as-multi-objective-optimization
因?yàn)楦恼碌淖髡叨际抢鲜烊?,這里就不尬吹了,大家有興趣的可以仔細(xì)讀一讀,干貨滿滿。

作者:楊奎元-深動(dòng)
鏈接:https://www.zhihu.com/question/375794498/answer/1050963528
1. 一般都是多個(gè)loss之間平衡,即使是單任務(wù),也會(huì)有weight decay項(xiàng)。比較簡(jiǎn)單的組合一般通過(guò)調(diào)超參就可以。
2. 對(duì)于比較復(fù)雜的多任務(wù)loss之間平衡,這里推薦一篇通過(guò)網(wǎng)絡(luò)直接預(yù)測(cè)loss權(quán)重的方法[1]。以兩個(gè)loss為例,
和
由網(wǎng)絡(luò)輸出,由于整體loss要求最小,所以前兩項(xiàng)希望
越大越好,為防止退化,最后第三項(xiàng)則希望
越小越好。當(dāng)兩個(gè)loss中某個(gè)比較大時(shí),其對(duì)應(yīng)的
也會(huì)取較大值,使得整體loss最小化,也就自然處理量綱不一致或某個(gè)loss方差較大問(wèn)題。

該方法后來(lái)被拓展到了物體檢測(cè)領(lǐng)域[2],用于考慮每個(gè)2D框標(biāo)注可能存在的不確定性問(wèn)題。

[1] Alex Kendall, Yarin Gal, Roberto Cipolla. Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics. CVPR, 2018.
[2] Yihui He, Chenchen Zhu, Jianren Wang, Marios Savvides, Xiangyu Zhang. Bounding Box Regression with Uncertainty for Accurate Object Detection. CVPR, 2019.
作者:鄭澤嘉
鏈接:https://www.zhihu.com/question/375794498/answer/1056695768
Focal loss 會(huì)根據(jù)每個(gè)task的表現(xiàn)幫你自動(dòng)調(diào)整這些參數(shù)的。
我們的做法一般是先分幾個(gè)stage 訓(xùn)練。stage 0 : task 0, stage 1: task 0 and 1. 以此類推。在stage 1以后都用的是focal loss。
========== 沒(méi)想到我也可以二更 ===============
是這樣的。
首先對(duì)于每個(gè) Task,你有個(gè) Loss Function,以及一個(gè)映射到 [0, 1] 的 KPI (key performance indicator) 。比如對(duì)于分類任務(wù), Loss function 可以是 cross entropy loss,KPI 可以是 Accuracy 或者 Average Precision。對(duì)于 regression 來(lái)說(shuō)需要將 IOU 之類的歸一化到 [0, 1] 之間。KPI 越高表示這個(gè)任務(wù)表現(xiàn)越好。
對(duì)于每個(gè)進(jìn)來(lái)的 batch,每個(gè)Task_i 有個(gè) loss_i。每個(gè)Task i 還有個(gè)不同的 KPI: k_i。那根據(jù) Focal loss 的定義,F(xiàn)L(k_i, gamma_i) = -((1 - k_i)^gamma_i) * log(k_i)。一般來(lái)說(shuō)我們gamma 取 2。
于是對(duì)于這個(gè) batch 來(lái)說(shuō),整個(gè) loss = sum(FL(k_i, gamma_i) * loss_i)
在直觀上說(shuō),這個(gè) FL,當(dāng)一個(gè)任務(wù)的 KPI 接近 0 的時(shí)候會(huì)趨于無(wú)限大,使得你的 loss 完全被那個(gè)表現(xiàn)不好的 task 給 dominate。這樣你的back prop 就會(huì)讓所有的權(quán)重根據(jù)那個(gè)kpi 不好的任務(wù)調(diào)整。當(dāng)一個(gè)任務(wù)表現(xiàn)特別好 KPI 接近 1 的時(shí)候,F(xiàn)L 就會(huì)是0,在整個(gè) loss 里的比重也會(huì)變得很小。
當(dāng)然根據(jù)學(xué)習(xí)的速率不同有可能一開始學(xué)的不好的task后面反超其他task。http://svl.stanford.edu/assets/papers/guo2018focus.pdf 這篇文章里講了如何像momentum 一樣的逐漸更新 KPI。
由于整個(gè) loss 里現(xiàn)在也要對(duì) KPI 求導(dǎo),所以文章里還有一些對(duì)于 KPI 求導(dǎo)的推導(dǎo)。
當(dāng)然我們也說(shuō)了,KPI 接近 0 時(shí),Loss 會(huì)變得很大,所以一開始訓(xùn)練的時(shí)候不要用focal loss,要確保網(wǎng)絡(luò)的權(quán)重更新到一定時(shí)候再加入 focal loss。
希望大家訓(xùn)練愉快。
作者:Hanson
鏈接:https://www.zhihu.com/question/375794498/answer/1077922077
對(duì)于多任務(wù)學(xué)習(xí)而言,它每一組loss之間的數(shù)量級(jí)和學(xué)習(xí)難度并不一樣,尋找平衡點(diǎn)是個(gè)很難的事情。我舉兩個(gè)我在實(shí)際應(yīng)用中碰到的問(wèn)題。
第一個(gè)是多任務(wù)學(xué)習(xí)算法MTCNN,這算是人臉檢測(cè)領(lǐng)域最經(jīng)典的算法之一,被各家廠商魔改,其性能也是很不錯(cuò)的,也有很多版本的開源實(shí)現(xiàn)(如果不了解的話,傳送門)。但是我在測(cè)試各種實(shí)現(xiàn)的過(guò)程中,發(fā)現(xiàn)竟然沒(méi)有一套實(shí)現(xiàn)是超越了原版的。下圖中是不同版本的實(shí)現(xiàn),打了碼的是我復(fù)現(xiàn)的結(jié)果。

這是一件很困擾的事情,參數(shù)、網(wǎng)絡(luò)結(jié)構(gòu)大家設(shè)置都大差不差。但效果確實(shí)是迥異。
clsloss表示置信度score的loss,boxloss表示預(yù)測(cè)框位置box的loss,landmarksloss表示關(guān)鍵點(diǎn)位置landmarks的loss。
那么
這幾個(gè)權(quán)值,究竟應(yīng)該設(shè)置為什么樣的才能得到一個(gè)不錯(cuò)的結(jié)果呢?
其實(shí)有個(gè)比較不錯(cuò)的注意,就是只保留必要的那兩組權(quán)值,把另外一組設(shè)置為0,比如
。為什么這么做?第一是因?yàn)殛P(guān)鍵點(diǎn)的回歸在人臉檢測(cè)過(guò)程中不是必要的,去了這部分依舊沒(méi)什么大問(wèn)題,也只有在這個(gè)假設(shè)的前提下才能進(jìn)行接下來(lái)的實(shí)驗(yàn)。
就比如這個(gè)MTCNN中的ONet,它回歸了包括score、bbox、landmarks,我在用pytorch復(fù)現(xiàn)的時(shí)候,出現(xiàn)一些有意思的情況,就是將landmarks這條任務(wù)凍結(jié)后(即
),發(fā)現(xiàn)ONet的性能得到了巨大的提升。能超越原始版本的性能。
但是加上landmarks任務(wù)后(
)就會(huì)對(duì)cls_loss造成影響,這就是一個(gè)矛盾的現(xiàn)象。而且和a、b、c對(duì)應(yīng)的大小有很大關(guān)系。當(dāng)設(shè)置成(
)的時(shí)候關(guān)鍵點(diǎn)的精度真的是慘不忍睹,幾乎沒(méi)法用。當(dāng)設(shè)置成(
)的時(shí)候,loss到了同樣一個(gè)數(shù)量級(jí),landmarks的精度確實(shí)是上去了,但是score卻不怎么讓人滿意。如果產(chǎn)生了這種現(xiàn)象,就證明了這個(gè)網(wǎng)絡(luò)結(jié)構(gòu)在設(shè)計(jì)的時(shí)候出現(xiàn)了一些缺陷,需要去修改backbone之后的multi-task分支,讓兩者的相關(guān)性盡量減小?;蛘呤荗Net就不去做關(guān)鍵點(diǎn),而是選擇單獨(dú)的一個(gè)網(wǎng)絡(luò)去做關(guān)鍵點(diǎn)的預(yù)測(cè)(比如追加一個(gè)LNet)。box的回歸并不是特別受關(guān)鍵點(diǎn)影響,大部分情況box和landmarks是正向促進(jìn)的,影響程度可以看做和score是一致的,box的精度即便下降了5%,它還是能框得住目標(biāo),因此不用太在意。
上面這個(gè)實(shí)驗(yàn)意在說(shuō)明,要存在就好的loss權(quán)重組合,那么你的網(wǎng)絡(luò)結(jié)構(gòu)就必須設(shè)計(jì)的足夠好。不然你可能還需要通過(guò)上述的實(shí)驗(yàn)就驗(yàn)證你的網(wǎng)絡(luò)結(jié)構(gòu)。從多種策略的設(shè)計(jì)上去解決這中l(wèi)oss不均衡造成的困擾。
好消息!
小白學(xué)視覺(jué)知識(shí)星球
開始面向外開放啦??????
下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程 在「小白學(xué)視覺(jué)」公眾號(hào)后臺(tái)回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺(jué)、目標(biāo)跟蹤、生物視覺(jué)、超分辨率處理等二十多章內(nèi)容。 下載2:Python視覺(jué)實(shí)戰(zhàn)項(xiàng)目52講 在「小白學(xué)視覺(jué)」公眾號(hào)后臺(tái)回復(fù):Python視覺(jué)實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測(cè)、車道線檢測(cè)、車輛計(jì)數(shù)、添加眼線、車牌識(shí)別、字符識(shí)別、情緒檢測(cè)、文本內(nèi)容提取、面部識(shí)別等31個(gè)視覺(jué)實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺(jué)。 下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講 在「小白學(xué)視覺(jué)」公眾號(hào)后臺(tái)回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。 交流群
歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺(jué)、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測(cè)、分割、識(shí)別、醫(yī)學(xué)影像、GAN、算法競(jìng)賽等微信群(以后會(huì)逐漸細(xì)分),請(qǐng)掃描下面微信號(hào)加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺(jué)SLAM“。請(qǐng)按照格式備注,否則不予通過(guò)。添加成功后會(huì)根據(jù)研究方向邀請(qǐng)進(jìn)入相關(guān)微信群。請(qǐng)勿在群內(nèi)發(fā)送廣告,否則會(huì)請(qǐng)出群,謝謝理解~

