提升分類模型acc(一):BatchSize&LARS
【GiantPandaCV導(dǎo)讀】在使用大的bs訓(xùn)練情況下,會(huì)對(duì)精度有一定程度的損失,本文探討了訓(xùn)練的bs大小對(duì)精度的影響,同時(shí)探究Layer-wise Adaptive Rate Scaling(LARS)是否可以有效的提升精度。
論文鏈接:https://arxiv.org/abs/1708.03888論文代碼: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
知乎專欄: https://zhuanlan.zhihu.com/p/406882110
1引言
如何提升業(yè)務(wù)分類模型的性能,一直是個(gè)難題,畢竟沒(méi)有99.999%的性能都會(huì)帶來(lái)一定程度的風(fēng)險(xiǎn),所以很多時(shí)候我們只能通過(guò)控制閾值來(lái)調(diào)整準(zhǔn)召以達(dá)到想要的效果。本系列主要探究哪些模型trick和數(shù)據(jù)的方法可以大幅度讓你的分類性能更上一層樓,不過(guò)要注意一點(diǎn)的是,tirck不一定是適用于不同的數(shù)據(jù)場(chǎng)景的,但是數(shù)據(jù)處理方法是普適的。本篇文章主要是對(duì)于大的bs下訓(xùn)練分類模型的情況,如果bs比較小的可以忽略,直接看最后的結(jié)論就好了(這個(gè)系列以后的文章講述的方法是通用的,無(wú)論bs大小都可以用)。
2實(shí)驗(yàn)配置
模型:ResNet50 數(shù)據(jù):ImageNet1k 環(huán)境:8xV100
3BatchSize對(duì)精度的影響
所有的實(shí)驗(yàn)的超參都是統(tǒng)一的,warmup 5個(gè)epoch,訓(xùn)練90個(gè)epoch,StepLR進(jìn)行衰減,學(xué)習(xí)率的設(shè)置和bs線性相關(guān),公式為,優(yōu)化器使用帶有0.9的動(dòng)量的SGD,baselr為0.1(如果采用Adam或者AdamW優(yōu)化器的話,公式需要調(diào)整為),訓(xùn)練的數(shù)據(jù)增強(qiáng)只有RandomCropResize,RandomFlip,驗(yàn)證的數(shù)據(jù)增強(qiáng)為Resize和CenterCrop。
訓(xùn)練情況如下:
lr調(diào)整曲線如下: 
訓(xùn)練曲線如下: 
驗(yàn)證曲線如下: 
我這里設(shè)計(jì)了4組對(duì)照實(shí)驗(yàn),256, 1024, 2048和4096的batchsize,開(kāi)了FP16也只能跑到了4096了。采用的是分布式訓(xùn)練,所以單張卡的bs就是bs = total_bs / ngpus_per_node。這里我沒(méi)有使用跨卡bn,對(duì)于bs 64單卡來(lái)說(shuō)理論上已經(jīng)很大了,bn的作用是約束數(shù)據(jù)分布,64的bs已經(jīng)可以表達(dá)一個(gè)分布的subset了,再大的bs還是同分布的,意義不大,跨卡bn的速度也更慢,所以大的bs基本可以忽略這個(gè)問(wèn)題。但是對(duì)于檢測(cè)的任務(wù),跨卡bn還是有價(jià)值的,畢竟輸入的分辨率大,單卡的bs比較小,一般4,8,16,這時(shí)候統(tǒng)計(jì)更大的bn會(huì)對(duì)模型收斂更好。
實(shí)驗(yàn)結(jié)果如下:
| 模型 | epoch | LR | batchsize | dataaug | acc@top1 |
|---|---|---|---|---|---|
| ResNet50 | 90 | 0.1 | 256 | randomcropresize,randomflip | 76.422% |
| ResNet50 | 90 | 0.4 | 1024 | randomcropresize,randomflip | 76.228% |
| ResNet50 | 90 | 0.8 | 2048 | randomcropresize,randomflip | 76.132% |
| ResNet50 | 90 | 1.6 | 4096 | randomcropresize,randomflip | 75.75% |
很明顯可以看出來(lái),當(dāng)bs增加到4k的時(shí)候,acc下降了將近0.8%個(gè)點(diǎn),1k的時(shí)候,下降了0.2%個(gè)點(diǎn),所以,通常我們用大的bs訓(xùn)練的時(shí)候,是沒(méi)辦法達(dá)到最優(yōu)的精度的。個(gè)人建議,使用1k的bs和0.4的學(xué)習(xí)率最優(yōu)。
4LARS(Layer-wise Adaptive Rate Scaling)
1. 理論分析
由于bs的增加,在同樣的epoch的情況下,會(huì)使網(wǎng)絡(luò)的weights更新迭代的次數(shù)變少,所以需要對(duì)LR隨著bs的增加而線性增加,但是這樣會(huì)導(dǎo)致上面我們看到的問(wèn)題,過(guò)大的lr會(huì)導(dǎo)致最終的收斂不穩(wěn)定,精度有所下降。
LARS的出發(fā)點(diǎn)則是各個(gè)層的更新參數(shù)使用的學(xué)習(xí)率應(yīng)該根據(jù)自己的情況有所調(diào)整,而不是所有層使用相同的學(xué)習(xí)率,也就是每層有自己的local lr,所以有:
這里,表示的是第幾層,表示的是超參數(shù),這個(gè)超參數(shù)遠(yuǎn)小于1,表示每層會(huì)改變參數(shù)的confidence,局部學(xué)習(xí)率可以很方便的替換每層的全局學(xué)習(xí)率,參數(shù)的更新大小為:
與SGD聯(lián)合使用的算法如下:
LARS代碼如下:
class LARC(object):
def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
self.optim = optimizer
self.trust_coefficient = trust_coefficient
self.eps = eps
self.clip = clip
def step(self):
with torch.no_grad():
weight_decays = []
for group in self.optim.param_groups:
# absorb weight decay control from optimizer
weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
weight_decays.append(weight_decay)
group['weight_decay'] = 0
for p in group['params']:
if p.grad is None:
continue
param_norm = torch.norm(p.data)
grad_norm = torch.norm(p.grad.data)
if param_norm != 0 and grad_norm != 0:
# calculate adaptive lr + weight decay
adaptive_lr = self.trust_coefficient * (param_norm) / (
grad_norm + param_norm * weight_decay + self.eps)
# clip learning rate for LARC
if self.clip:
# calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
adaptive_lr = min(adaptive_lr / group['lr'], 1)
p.grad.data += weight_decay * p.data
p.grad.data *= adaptive_lr
self.optim.step()
# return weight decay control to optimizer
for i, group in enumerate(self.optim.param_groups):
group['weight_decay'] = weight_decays[i]
這里有一個(gè)超參數(shù),trust_coefficient,也就是公式里面所提到的, 這個(gè)參數(shù)對(duì)精度的影響比較大,實(shí)驗(yàn)部分我們會(huì)給出結(jié)論。
2. 實(shí)驗(yàn)結(jié)論
| 模型 | epoch | LR | batchsize | dataaug | acc@top1 | trust_confidence |
|---|---|---|---|---|---|---|
| ResNet50 | 90 | 0.4 | 1024 | randomcropresize,randomflip | 75.146% | 1e-3 |
| ResNet50 | 90 | 0.8 | 2048 | randomcropresize,randomflip | 73.946% | 1e-3 |
| ResNet50 | 90 | 1.6 | 4096 | randomcropresize,randomflip | 72.396% | 1e-3 |
| ResNet50 | 90 | 0.4 | 1024 | randomcropresize,randomflip | 76.234% | 2e-2 |
| ResNet50 | 90 | 0.8 | 2048 | randomcropresize,randomflip | 75.898% | 2e-2 |
| ResNet50 | 90 | 1.6 | 4096 | randomcropresize,randomflip | 75.842% | 2e-2 |
可以很明顯發(fā)現(xiàn),使用了LARS,設(shè)置turst_confidence為1e-3的情況下,有著明顯的掉點(diǎn),設(shè)置為2e-2的時(shí)候,在1k和4k的情況下,有著明顯的提升,但是2k的情況下有所下降。
LARS一定程度上可以提升精度,但是強(qiáng)依賴超參,還是需要細(xì)致的調(diào)參訓(xùn)練。
5結(jié)論
8卡進(jìn)行分布式訓(xùn)練,使用1k的bs可以很好的平衡acc&speed。 LARS一定程度上可以提升精度,但是需要調(diào)參,做業(yè)務(wù)可以不用考慮,刷點(diǎn)的話要好好訓(xùn)練。
6結(jié)束語(yǔ)
本文是提升分類模型acc系列的第一篇,后續(xù)會(huì)講解一些通用的trick和數(shù)據(jù)處理的方法,敬請(qǐng)關(guān)注。

END
掃碼加交流群
GiantPandaCV
