<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          提升分類模型acc(一):BatchSize&LARS

          共 7513字,需瀏覽 16分鐘

           ·

          2021-09-10 16:30


          【GiantPandaCV導(dǎo)讀】在使用大的bs訓(xùn)練情況下,會(huì)對(duì)精度有一定程度的損失,本文探討了訓(xùn)練的bs大小對(duì)精度的影響,同時(shí)探究Layer-wise Adaptive Rate Scaling(LARS)是否可以有效的提升精度。


          論文鏈接:https://arxiv.org/abs/1708.03888

          論文代碼: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
          知乎專欄: https://zhuanlan.zhihu.com/p/406882110

          1引言

          如何提升業(yè)務(wù)分類模型的性能,一直是個(gè)難題,畢竟沒(méi)有99.999%的性能都會(huì)帶來(lái)一定程度的風(fēng)險(xiǎn),所以很多時(shí)候我們只能通過(guò)控制閾值來(lái)調(diào)整準(zhǔn)召以達(dá)到想要的效果。本系列主要探究哪些模型trick和數(shù)據(jù)的方法可以大幅度讓你的分類性能更上一層樓,不過(guò)要注意一點(diǎn)的是,tirck不一定是適用于不同的數(shù)據(jù)場(chǎng)景的,但是數(shù)據(jù)處理方法是普適的。本篇文章主要是對(duì)于大的bs下訓(xùn)練分類模型的情況,如果bs比較小的可以忽略,直接看最后的結(jié)論就好了(這個(gè)系列以后的文章講述的方法是通用的,無(wú)論bs大小都可以用)。

          2實(shí)驗(yàn)配置

          • 模型:ResNet50
          • 數(shù)據(jù):ImageNet1k
          • 環(huán)境:8xV100

          3BatchSize對(duì)精度的影響

          所有的實(shí)驗(yàn)的超參都是統(tǒng)一的,warmup 5個(gè)epoch,訓(xùn)練90個(gè)epoch,StepLR進(jìn)行衰減,學(xué)習(xí)率的設(shè)置和bs線性相關(guān),公式為,優(yōu)化器使用帶有0.9的動(dòng)量的SGD,baselr為0.1(如果采用Adam或者AdamW優(yōu)化器的話,公式需要調(diào)整為),訓(xùn)練的數(shù)據(jù)增強(qiáng)只有RandomCropResize,RandomFlip,驗(yàn)證的數(shù)據(jù)增強(qiáng)為ResizeCenterCrop

          訓(xùn)練情況如下:

          • lr調(diào)整曲線如下:
          • 訓(xùn)練曲線如下:
          • 驗(yàn)證曲線如下:

          我這里設(shè)計(jì)了4組對(duì)照實(shí)驗(yàn),256, 1024, 2048和4096的batchsize,開(kāi)了FP16也只能跑到了4096了。采用的是分布式訓(xùn)練,所以單張卡的bs就是bs = total_bs / ngpus_per_node。這里我沒(méi)有使用跨卡bn,對(duì)于bs 64單卡來(lái)說(shuō)理論上已經(jīng)很大了,bn的作用是約束數(shù)據(jù)分布,64的bs已經(jīng)可以表達(dá)一個(gè)分布的subset了,再大的bs還是同分布的,意義不大,跨卡bn的速度也更慢,所以大的bs基本可以忽略這個(gè)問(wèn)題。但是對(duì)于檢測(cè)的任務(wù),跨卡bn還是有價(jià)值的,畢竟輸入的分辨率大,單卡的bs比較小,一般4,8,16,這時(shí)候統(tǒng)計(jì)更大的bn會(huì)對(duì)模型收斂更好。

          實(shí)驗(yàn)結(jié)果如下:

          模型epochLRbatchsizedataaugacc@top1
          ResNet50900.1256randomcropresize,randomflip76.422%
          ResNet50900.41024randomcropresize,randomflip76.228%
          ResNet50900.82048randomcropresize,randomflip76.132%
          ResNet50901.64096randomcropresize,randomflip75.75%

          很明顯可以看出來(lái),當(dāng)bs增加到4k的時(shí)候,acc下降了將近0.8%個(gè)點(diǎn),1k的時(shí)候,下降了0.2%個(gè)點(diǎn),所以,通常我們用大的bs訓(xùn)練的時(shí)候,是沒(méi)辦法達(dá)到最優(yōu)的精度的。個(gè)人建議,使用1k的bs和0.4的學(xué)習(xí)率最優(yōu)。

          4LARS(Layer-wise Adaptive Rate Scaling)

          1. 理論分析

          由于bs的增加,在同樣的epoch的情況下,會(huì)使網(wǎng)絡(luò)的weights更新迭代的次數(shù)變少,所以需要對(duì)LR隨著bs的增加而線性增加,但是這樣會(huì)導(dǎo)致上面我們看到的問(wèn)題,過(guò)大的lr會(huì)導(dǎo)致最終的收斂不穩(wěn)定,精度有所下降。

          LARS的出發(fā)點(diǎn)則是各個(gè)層的更新參數(shù)使用的學(xué)習(xí)率應(yīng)該根據(jù)自己的情況有所調(diào)整,而不是所有層使用相同的學(xué)習(xí)率,也就是每層有自己的local lr,所以有:

          這里,表示的是第幾層,表示的是超參數(shù),這個(gè)超參數(shù)遠(yuǎn)小于1,表示每層會(huì)改變參數(shù)的confidence,局部學(xué)習(xí)率可以很方便的替換每層的全局學(xué)習(xí)率,參數(shù)的更新大小為:

          與SGD聯(lián)合使用的算法如下:

          LARS代碼如下:

          class LARC(object):
              def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
                  self.optim = optimizer
                  self.trust_coefficient = trust_coefficient
                  self.eps = eps
                  self.clip = clip

              def step(self):
                  with torch.no_grad():
                      weight_decays = []
                      for group in self.optim.param_groups:
                          # absorb weight decay control from optimizer
                          weight_decay = group['weight_decay'if 'weight_decay' in group else 0
                          weight_decays.append(weight_decay)
                          group['weight_decay'] = 0
                          for p in group['params']:
                              if p.grad is None:
                                  continue
                              param_norm = torch.norm(p.data)
                              grad_norm = torch.norm(p.grad.data)

                              if param_norm != 0 and grad_norm != 0:
                                  # calculate adaptive lr + weight decay
                                  adaptive_lr = self.trust_coefficient * (param_norm) / (
                                              grad_norm + param_norm * weight_decay + self.eps)

                                  # clip learning rate for LARC
                                  if self.clip:
                                      # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
                                      adaptive_lr = min(adaptive_lr / group['lr'], 1)

                                  p.grad.data += weight_decay * p.data
                                  p.grad.data *= adaptive_lr

                  self.optim.step()
                  # return weight decay control to optimizer
                  for i, group in enumerate(self.optim.param_groups):
                      group['weight_decay'] = weight_decays[i]

          這里有一個(gè)超參數(shù),trust_coefficient,也就是公式里面所提到的, 這個(gè)參數(shù)對(duì)精度的影響比較大,實(shí)驗(yàn)部分我們會(huì)給出結(jié)論。

          2. 實(shí)驗(yàn)結(jié)論

          模型epochLRbatchsizedataaugacc@top1trust_confidence
          ResNet50900.41024randomcropresize,randomflip75.146%1e-3
          ResNet50900.82048randomcropresize,randomflip73.946%1e-3
          ResNet50901.64096randomcropresize,randomflip72.396%1e-3
          ResNet50900.41024randomcropresize,randomflip76.234%2e-2
          ResNet50900.82048randomcropresize,randomflip75.898%2e-2
          ResNet50901.64096randomcropresize,randomflip75.842%2e-2

          可以很明顯發(fā)現(xiàn),使用了LARS,設(shè)置turst_confidence為1e-3的情況下,有著明顯的掉點(diǎn),設(shè)置為2e-2的時(shí)候,在1k和4k的情況下,有著明顯的提升,但是2k的情況下有所下降。

          LARS一定程度上可以提升精度,但是強(qiáng)依賴超參,還是需要細(xì)致的調(diào)參訓(xùn)練。

          5結(jié)論

          • 8卡進(jìn)行分布式訓(xùn)練,使用1k的bs可以很好的平衡acc&speed。
          • LARS一定程度上可以提升精度,但是需要調(diào)參,做業(yè)務(wù)可以不用考慮,刷點(diǎn)的話要好好訓(xùn)練。

          6結(jié)束語(yǔ)

          本文是提升分類模型acc系列的第一篇,后續(xù)會(huì)講解一些通用的trick和數(shù)據(jù)處理的方法,敬請(qǐng)關(guān)注。




          END




             掃碼加交流群


          GiantPandaCV




          瀏覽 74
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  91在线无精精品秘 白丝 | 亚洲婷婷国产 | 成人高清无码视频在线免费观看 | 激情五月婷婷网 | 人人妻人人澡人人爽人人DVD |