深度學(xué)習(xí)的多個(gè) loss 是如何平衡的?
點(diǎn)擊上方“程序員大白”,選擇“星標(biāo)”公眾號(hào)
重磅干貨,第一時(shí)間送達(dá)
來(lái)自 | 知乎?
地址 | https://www.zhihu.com/question/375794498
編輯 | AI有道?
在一個(gè)端到端訓(xùn)練的網(wǎng)絡(luò)中,如果最終的loss = a*loss1+b*loss2+c*loss3...,對(duì)于a,b,c這些超參的選擇,有沒(méi)有什么方法?
作者:Evan
https://www.zhihu.com/question/375794498/answer/1052779937
其實(shí)這是目前深度學(xué)習(xí)領(lǐng)域被某種程度上忽視了的一個(gè)重要問(wèn)題,在近幾年大火的multi-task learning,generative adversarial networks, 等等很多機(jī)器學(xué)習(xí)任務(wù)和方法里面都會(huì)遇到,很多paper的做法都是暴力調(diào)參結(jié)果玄學(xué)……這里偷偷跟大家分享兩個(gè)很有趣的研究視角
1. 從預(yù)測(cè)不確定性的角度引入Bayesian框架,根據(jù)各個(gè)loss分量當(dāng)前的大小自動(dòng)設(shè)定其權(quán)重。有代表性的工作參見(jiàn)Alex Kendall等人的CVPR2018文章:
Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics
https://arxiv.org/abs/1705.07115
文章的二作Yarin Gal是Zoubin Ghahramani的高徒,近幾年結(jié)合Bayesian思想和深度學(xué)習(xí)做了很多solid的工作。
2. 構(gòu)建所有l(wèi)oss的Pareto,以一次訓(xùn)練的超低代價(jià)得到多種超參組合對(duì)應(yīng)的結(jié)果。有代表性的工作參見(jiàn)Intel在2018年NeurIPS(對(duì),就是那個(gè)剛改了名字的機(jī)器學(xué)習(xí)頂會(huì))發(fā)表的:
Multi-Task Learning as Multi-Objective Optimization
http://papers.nips.cc/paper/7334-multi-task-learning-as-multi-objective-optimization
因?yàn)楦恼碌淖髡叨际抢鲜烊?,這里就不尬吹了,大家有興趣的可以仔細(xì)讀一讀,干貨滿滿。

作者:楊奎元-深動(dòng)
鏈接:https://www.zhihu.com/question/375794498/answer/1050963528
1. 一般都是多個(gè)loss之間平衡,即使是單任務(wù),也會(huì)有weight decay項(xiàng)。比較簡(jiǎn)單的組合一般通過(guò)調(diào)超參就可以。
2. 對(duì)于比較復(fù)雜的多任務(wù)loss之間平衡,這里推薦一篇通過(guò)網(wǎng)絡(luò)直接預(yù)測(cè)loss權(quán)重的方法[1]。以兩個(gè)loss為例,
?和
由網(wǎng)絡(luò)輸出,由于整體loss要求最小,所以前兩項(xiàng)希望
越大越好,為防止退化,最后第三項(xiàng)則希望
越小越好。當(dāng)兩個(gè)loss中某個(gè)比較大時(shí),其對(duì)應(yīng)的
也會(huì)取較大值,使得整體loss最小化,也就自然處理量綱不一致或某個(gè)loss方差較大問(wèn)題。

該方法后來(lái)被拓展到了物體檢測(cè)領(lǐng)域[2],用于考慮每個(gè)2D框標(biāo)注可能存在的不確定性問(wèn)題。

[1] Alex Kendall, Yarin Gal, Roberto Cipolla. Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics. CVPR, 2018.
[2] Yihui He, Chenchen Zhu, Jianren Wang, Marios Savvides, Xiangyu Zhang. Bounding Box Regression with Uncertainty for Accurate Object Detection. CVPR, 2019.
作者:鄭澤嘉
鏈接:https://www.zhihu.com/question/375794498/answer/1056695768
Focal loss 會(huì)根據(jù)每個(gè)task的表現(xiàn)幫你自動(dòng)調(diào)整這些參數(shù)的。
我們的做法一般是先分幾個(gè)stage 訓(xùn)練。stage ?0 : task 0, stage 1: task 0 and 1. 以此類推。在stage 1以后都用的是focal loss。
========== 沒(méi)想到我也可以二更 ===============
是這樣的。
首先對(duì)于每個(gè) Task,你有個(gè) Loss Function,以及一個(gè)映射到 [0, 1] 的 KPI (key performance indicator) 。比如對(duì)于分類任務(wù), Loss function 可以是 cross entropy loss,KPI 可以是 Accuracy 或者 Average Precision。對(duì)于 regression 來(lái)說(shuō)需要將 IOU 之類的歸一化到 [0, 1] 之間。KPI 越高表示這個(gè)任務(wù)表現(xiàn)越好。
對(duì)于每個(gè)進(jìn)來(lái)的 batch,每個(gè)Task_i 有個(gè) loss_i。每個(gè)Task i 還有個(gè)不同的 KPI: ?k_i。那根據(jù) Focal loss 的定義,F(xiàn)L(k_i, gamma_i) = -((1 - k_i)^gamma_i) * log(k_i)。一般來(lái)說(shuō)我們gamma 取 2。
于是對(duì)于這個(gè) batch 來(lái)說(shuō),整個(gè) loss = sum(FL(k_i, gamma_i) * loss_i)
在直觀上說(shuō),這個(gè) FL,當(dāng)一個(gè)任務(wù)的 KPI 接近 0 的時(shí)候會(huì)趨于無(wú)限大,使得你的 loss 完全被那個(gè)表現(xiàn)不好的 task 給 dominate。這樣你的back prop 就會(huì)讓所有的權(quán)重根據(jù)那個(gè)kpi 不好的任務(wù)調(diào)整。當(dāng)一個(gè)任務(wù)表現(xiàn)特別好 KPI 接近 1 的時(shí)候,F(xiàn)L 就會(huì)是0,在整個(gè) loss 里的比重也會(huì)變得很小。
當(dāng)然根據(jù)學(xué)習(xí)的速率不同有可能一開(kāi)始學(xué)的不好的task后面反超其他task。http://svl.stanford.edu/assets/papers/guo2018focus.pdf 這篇文章里講了如何像momentum 一樣的逐漸更新 KPI。
由于整個(gè) loss 里現(xiàn)在也要對(duì) KPI 求導(dǎo),所以文章里還有一些對(duì)于 KPI 求導(dǎo)的推導(dǎo)。
當(dāng)然我們也說(shuō)了,KPI 接近 0 時(shí),Loss 會(huì)變得很大,所以一開(kāi)始訓(xùn)練的時(shí)候不要用focal loss,要確保網(wǎng)絡(luò)的權(quán)重更新到一定時(shí)候再加入 focal loss。
希望大家訓(xùn)練愉快。
作者:Hanson
鏈接:https://www.zhihu.com/question/375794498/answer/1077922077
對(duì)于多任務(wù)學(xué)習(xí)而言,它每一組loss之間的數(shù)量級(jí)和學(xué)習(xí)難度并不一樣,尋找平衡點(diǎn)是個(gè)很難的事情。我舉兩個(gè)我在實(shí)際應(yīng)用中碰到的問(wèn)題。
第一個(gè)是多任務(wù)學(xué)習(xí)算法MTCNN,這算是人臉檢測(cè)領(lǐng)域最經(jīng)典的算法之一,被各家廠商魔改,其性能也是很不錯(cuò)的,也有很多版本的開(kāi)源實(shí)現(xiàn)(如果不了解的話,傳送門(mén))。但是我在測(cè)試各種實(shí)現(xiàn)的過(guò)程中,發(fā)現(xiàn)竟然沒(méi)有一套實(shí)現(xiàn)是超越了原版的。下圖中是不同版本的實(shí)現(xiàn),打了碼的是我復(fù)現(xiàn)的結(jié)果。

這是一件很困擾的事情,參數(shù)、網(wǎng)絡(luò)結(jié)構(gòu)大家設(shè)置都大差不差。但效果確實(shí)是迥異。
clsloss表示置信度score的loss,boxloss表示預(yù)測(cè)框位置box的loss,landmarksloss表示關(guān)鍵點(diǎn)位置landmarks的loss。
那么
這幾個(gè)權(quán)值,究竟應(yīng)該設(shè)置為什么樣的才能得到一個(gè)不錯(cuò)的結(jié)果呢?
其實(shí)有個(gè)比較不錯(cuò)的注意,就是只保留必要的那兩組權(quán)值,把另外一組設(shè)置為0,比如
。為什么這么做?第一是因?yàn)殛P(guān)鍵點(diǎn)的回歸在人臉檢測(cè)過(guò)程中不是必要的,去了這部分依舊沒(méi)什么大問(wèn)題,也只有在這個(gè)假設(shè)的前提下才能進(jìn)行接下來(lái)的實(shí)驗(yàn)。
就比如這個(gè)MTCNN中的ONet,它回歸了包括score、bbox、landmarks,我在用pytorch復(fù)現(xiàn)的時(shí)候,出現(xiàn)一些有意思的情況,就是將landmarks這條任務(wù)凍結(jié)后(即
),發(fā)現(xiàn)ONet的性能得到了巨大的提升。能超越原始版本的性能。
但是加上landmarks任務(wù)后(
)就會(huì)對(duì)cls_loss造成影響,這就是一個(gè)矛盾的現(xiàn)象。而且和a、b、c對(duì)應(yīng)的大小有很大關(guān)系。當(dāng)設(shè)置成(
)的時(shí)候關(guān)鍵點(diǎn)的精度真的是慘不忍睹,幾乎沒(méi)法用。當(dāng)設(shè)置成(
)的時(shí)候,loss到了同樣一個(gè)數(shù)量級(jí),landmarks的精度確實(shí)是上去了,但是score卻不怎么讓人滿意。如果產(chǎn)生了這種現(xiàn)象,就證明了這個(gè)網(wǎng)絡(luò)結(jié)構(gòu)在設(shè)計(jì)的時(shí)候出現(xiàn)了一些缺陷,需要去修改backbone之后的multi-task分支,讓兩者的相關(guān)性盡量減小?;蛘呤荗Net就不去做關(guān)鍵點(diǎn),而是選擇單獨(dú)的一個(gè)網(wǎng)絡(luò)去做關(guān)鍵點(diǎn)的預(yù)測(cè)(比如追加一個(gè)LNet)。box的回歸并不是特別受關(guān)鍵點(diǎn)影響,大部分情況box和landmarks是正向促進(jìn)的,影響程度可以看做和score是一致的,box的精度即便下降了5%,它還是能框得住目標(biāo),因此不用太在意。
上面這個(gè)實(shí)驗(yàn)意在說(shuō)明,要存在就好的loss權(quán)重組合,那么你的網(wǎng)絡(luò)結(jié)構(gòu)就必須設(shè)計(jì)的足夠好。不然你可能還需要通過(guò)上述的實(shí)驗(yàn)就驗(yàn)證你的網(wǎng)絡(luò)結(jié)構(gòu)。從多種策略的設(shè)計(jì)上去解決這中l(wèi)oss不均衡造成的困擾。
推薦閱讀
關(guān)于程序員大白
程序員大白是一群哈工大,東北大學(xué),西湖大學(xué)和上海交通大學(xué)的碩士博士運(yùn)營(yíng)維護(hù)的號(hào),大家樂(lè)于分享高質(zhì)量文章,喜歡總結(jié)知識(shí),歡迎關(guān)注[程序員大白],大家一起學(xué)習(xí)進(jìn)步!

