<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          BatchNorm的避坑指南(上)

          共 8755字,需瀏覽 18分鐘

           ·

          2021-06-18 10:44

          點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師

          設(shè)為星標(biāo),干貨直達(dá)!



          BatchNorm作為一種特征歸一化方法基本是CNN網(wǎng)絡(luò)的標(biāo)配。BatchNorm可以加快模型收斂速度,防止過(guò)擬合,對(duì)學(xué)習(xí)速率更魯棒,但是BatchNorm由于在batch上進(jìn)行操作,如果使用不當(dāng)可能會(huì)帶來(lái)副作用。近期Facebook AI的論文Rethinking "Batch" in BatchNorm系統(tǒng)且全面地對(duì)BatchNorm可能會(huì)帶來(lái)的問(wèn)題做了總結(jié),同時(shí)也給出了一些規(guī)避方案和建議,堪稱一份“避坑指南”。

          BatchNorm

          BatchNorm主要在CNN網(wǎng)絡(luò)中應(yīng)用,對(duì)于NLP領(lǐng)域,常采用的transformer采用的是LayerNorm,所以這里只討論BatchNorm2D。在訓(xùn)練階段,對(duì)于shape為的mini-batch ,BatchNorm首先計(jì)算各個(gè)channel的均值和方差

          然后BatchNorm對(duì)shape為特征進(jìn)行歸一化:

          可以看到計(jì)算均值和方差是依賴batch的,這也就是BatchNorm的名字由來(lái)。在測(cè)試階段,BatchNorm采用的均值和方差是從訓(xùn)練過(guò)程估計(jì)的全局統(tǒng)計(jì)量(population statistics):,這兩個(gè)參數(shù)也是從訓(xùn)練數(shù)據(jù)學(xué)習(xí)到的參數(shù)(但不是可訓(xùn)練參數(shù),沒(méi)有BP過(guò)程)。常規(guī)的做法在訓(xùn)練階段采用EMA( exponential moving average,指數(shù)移動(dòng)平均,在TensorFlow中EMA產(chǎn)生的均值和方差稱為moving_meanmoving_var,而PyTorch則稱為running_meanrunning_var)來(lái)估計(jì):

          訓(xùn)練階段采用的是mini-batch統(tǒng)計(jì)量,而測(cè)試階段是采用全局統(tǒng)計(jì)量,這就造成了BatchNorm的訓(xùn)練和測(cè)試不一致問(wèn)題,這個(gè)后面會(huì)詳細(xì)討論。

          除了歸一化,BatchNorm還包含對(duì)各個(gè)channel的特征做affine transform(增加特征表征能力):

          這里的是可訓(xùn)練的參數(shù),但是這個(gè)過(guò)程其實(shí)沒(méi)有batch的參與,從實(shí)現(xiàn)上等價(jià)于額外增加一個(gè)depthwise 1 × 1卷積層。BatchNorm的麻煩主要來(lái)自于mini-batch統(tǒng)計(jì)量的計(jì)算和歸一化中,這個(gè)affine transform不是影響因素,所以后面的討論主要集中在前面。

          圍繞著batch所能帶來(lái)的問(wèn)題,論文共討論了BatchNorm的四個(gè)方面:

          • Population Statistics:EMA是否能夠準(zhǔn)確估計(jì)全局統(tǒng)計(jì)量以及PreciseBN;
          • Batch in Training and Testing:訓(xùn)練采用mini-batch統(tǒng)計(jì)量,而測(cè)試采用全局統(tǒng)計(jì)量,由此帶來(lái)的不一致問(wèn)題;
          • Batch from Different Domains:BatchNorm在multiple domains中遇到的問(wèn)題;
          • Information Leakage within a Batch:BatchNorm所導(dǎo)致的信息泄露問(wèn)題;

          第二個(gè)應(yīng)該是大家都熟知的問(wèn)題,但是其實(shí)BatchNorm可能影響的方面是很多的,如域適應(yīng)(domain adaptation)和對(duì)比學(xué)習(xí)中信息泄露問(wèn)題。另外,這里討論的4個(gè)方面也不是獨(dú)立的,它們往往交織在一起。

          Population Statistics

          訓(xùn)練過(guò)程中的均值和方差是mini-batch計(jì)算出來(lái)的,但是在推理階段往往是每次只處理一個(gè)sample,沒(méi)有辦法再計(jì)算依賴batch的統(tǒng)計(jì)量。BatchNorm采用的方法是訓(xùn)練過(guò)程中用EMA估計(jì)全局統(tǒng)計(jì)量,但是這個(gè)估計(jì)可能會(huì)夠好:當(dāng)較大時(shí),每個(gè)iteration中mini-batch的統(tǒng)計(jì)量對(duì)全局統(tǒng)計(jì)量貢獻(xiàn)很少,這會(huì)導(dǎo)致更新過(guò)慢;當(dāng)較大時(shí),每個(gè)iteration中mini-batch的統(tǒng)計(jì)量會(huì)起主導(dǎo)作用,導(dǎo)致估計(jì)值不能代表全局。一般情況取一個(gè)較大的值,如0.9或0.99,這是一個(gè)超參數(shù)。論文中在ResNet50的訓(xùn)練過(guò)程(256 GPU,每個(gè)GPU batch_size=32)隨機(jī)選擇模型的某個(gè)BatchNorm層的某個(gè)channel,繪制了其EMA mean以及population mean,這里的population mean采用當(dāng)前模型在100 mini-batches的batch mean的平均值來(lái)估計(jì),這個(gè)可以代表當(dāng)前模型的全局統(tǒng)計(jì)量,對(duì)比圖如下所示。在訓(xùn)練前期,從圖a可以看到EMA mean和當(dāng)前的batch mean是有距離的,而圖b說(shuō)明EMA mean是落后于當(dāng)前模型的近似全局統(tǒng)計(jì)量的,但是到訓(xùn)練中后期EMA mean就比較準(zhǔn)確了。


          這說(shuō)明EMA統(tǒng)計(jì)量在訓(xùn)練早期是有偏差的。一個(gè)準(zhǔn)確的全局統(tǒng)計(jì)量應(yīng)該是:使用整個(gè)訓(xùn)練集作為一個(gè)batch計(jì)算特征的均值和方差,但是這個(gè)計(jì)算成本太高了,論文中提出采用一種近似方法來(lái)計(jì)算:首先采用固定模型(訓(xùn)練好的)計(jì)算很多mini-batch;然后聚合每個(gè)mini-batch的統(tǒng)計(jì)量來(lái)得到全局統(tǒng)計(jì)量。假定共需要計(jì)算個(gè)samples,batch_size為,那么共計(jì)算個(gè)mini-batch,記它們的統(tǒng)計(jì)量為,那么全局統(tǒng)計(jì)量可以近似這樣計(jì)算:

          這其實(shí)只是一種聚合方式,論文附錄也討論了其它計(jì)算方式,結(jié)果是類似的。這種BatchNorm稱為PreciseBN,具體代碼實(shí)現(xiàn)可以參考fvcore.nn.precise_bn:

          class _PopulationVarianceEstimator:
              """
              Alternatively, one can estimate population variance by the sample variance
              of all batches combined. This needs to use the batch size of each batch
              in this function to undo the bessel-correction.
              This produces better estimation when each batch is small.
              See Appendix of the paper "Rethinking Batch in BatchNorm" for details.
              In this implementation, we also take into account varying batch sizes.
              A batch of N1 samples with a mean of M1 and a batch of N2 samples with a
              mean of M2 will produce a population mean of (N1M1+N2M2)/(N1+N2) instead
              of (M1+M2)/2.
              """


              def __init__(self, mean_buffer: torch.Tensor, var_buffer: torch.Tensor) -> None:
                  self.pop_mean: torch.Tensor = torch.zeros_like(mean_buffer) # population mean
                  self.pop_square_mean: torch.Tensor = torch.zeros_like(var_buffer) # population variance 
                  self.tot = 0 # total samples
              
              # update per mini-batch, is called by `update_bn_stats`
              def update(
                  self, batch_mean: torch.Tensor, batch_var: torch.Tensor, batch_size: int
              )
           -> None:

                  self.tot += batch_size
                  batch_square_mean = batch_mean.square() + batch_var * (
                      (batch_size - 1) / batch_size
                  )
                  self.pop_mean += (batch_mean - self.pop_mean) * (batch_size / self.tot)
                  self.pop_square_mean += (batch_square_mean - self.pop_square_mean) * (
                      batch_size / self.tot
                  )

              @property
              def pop_var(self) -> torch.Tensor:
                  return self.pop_square_mean - self.pop_mean.square()

          論文中以ResNet50的訓(xùn)練為例對(duì)比了EMA和PreciseBN的效果,如下圖所示,可以看到PreciseBN比EMA效果更加穩(wěn)定,特別是訓(xùn)練早期(此時(shí)模型未收斂),雖然最終兩者的效果接近。


          進(jìn)一步地,如果訓(xùn)練采用更大的batch size,實(shí)驗(yàn)發(fā)現(xiàn)EMA在訓(xùn)練過(guò)程中的抖動(dòng)更大,但此時(shí)PreciseBN效果比較穩(wěn)定。當(dāng)采用larger batch訓(xùn)練時(shí),學(xué)習(xí)速率增大,而且EMA更新次數(shù)減少,這些都會(huì)對(duì)EMA產(chǎn)生較大影響。綜上,雖然EMA和PreciseBN最終效果接近(因此EMA的缺點(diǎn)往往被忽視),但是在模型未收斂的訓(xùn)練早期,PreciseBN更加穩(wěn)定,像強(qiáng)化學(xué)習(xí)需要在訓(xùn)練早期評(píng)估模型效果這種場(chǎng)景,PreciseBN能帶來(lái)更加穩(wěn)定可靠的結(jié)果。


          此外,論文也通過(guò)實(shí)驗(yàn)證明了PreciseBN只需要 samples就可以得到比較好的結(jié)果,以ImageNet訓(xùn)練為例,采用PreciseBN評(píng)估只需要增加0.5%的訓(xùn)練時(shí)間。


          另外,論文里還對(duì)比了batch size對(duì)PreciseBN的影響。這里先理清楚兩個(gè)概念:(1)normalization batch size(NBS):實(shí)際計(jì)算統(tǒng)計(jì)量的mini-batch的size;(2)total batch size或者SGD batch size:每個(gè)iteration中mini-batch的size,或者說(shuō)每執(zhí)行一次SGD算法的batch size;兩者在多卡訓(xùn)練過(guò)程是不等同的(此時(shí)NBS是per-GPU batch size,而SyncBN可以實(shí)現(xiàn)兩者一致)。從結(jié)果來(lái)看,NBS較小時(shí),模型效果會(huì)變差,但是PreciseBN的batch size是相對(duì)NBS獨(dú)立的,所以選擇batch size 時(shí)PreciseBN可以取得穩(wěn)定的效果,并且在NBS較小時(shí)相比EMA提升效果


          Batch in Training and Testing

          前面已經(jīng)說(shuō)過(guò)BatchNorm在訓(xùn)練時(shí)采用的是mini-batch統(tǒng)計(jì)量,而測(cè)試時(shí)采用的全局統(tǒng)計(jì)量,這就導(dǎo)致了訓(xùn)練和測(cè)試的不一致性,從而帶來(lái)對(duì)模型性能的影響。為此,論文還是以ResNet50訓(xùn)練為例分析這種不一致帶來(lái)的影響,這里還同時(shí)比較了不同NBS帶來(lái)的差異(SGD batch size固定在1024,此時(shí)NBS從2~1024變化),分別對(duì)比不同NBS下的三個(gè)指標(biāo):(1)采用mini-batch統(tǒng)計(jì)量在訓(xùn)練集上的分類誤差;(2)采用mini-batch統(tǒng)計(jì)量在驗(yàn)證集上的分類誤差;(3)采用全局統(tǒng)計(jì)量在驗(yàn)證集上的分類誤差。這里(1)和(3)其實(shí)是常規(guī)評(píng)估方法,而(2)往往不會(huì)這樣做,但是(1)和(2)就保持一致了(訓(xùn)練和測(cè)試均采用mini-batch統(tǒng)計(jì)量)。對(duì)比結(jié)果如下所示,從中可以得到三個(gè)方面的結(jié)論:

          • training noise:訓(xùn)練集誤差隨著NBS增大而減少,這主要是由于SGD訓(xùn)練噪音所導(dǎo)致的,當(dāng)NBS較小時(shí),mini-batch統(tǒng)計(jì)量波動(dòng)大導(dǎo)致優(yōu)化困難,從而產(chǎn)生較大的訓(xùn)練誤差;
          • generalization gap:對(duì)比(1)和(2),兩者均采用mini-batch統(tǒng)計(jì)量,差異就來(lái)自數(shù)據(jù)集不同,這部分性能差異就是泛化gap;
          • train-test inconsistency:對(duì)比(2)和(3),兩者數(shù)據(jù)集一樣,但是(2)采用mini-batch統(tǒng)計(jì)量,而(3)采用全局統(tǒng)計(jì)量,這部分性能差異就是訓(xùn)練和測(cè)試不一致所導(dǎo)致的;

          另外,我們可以看到當(dāng)NBS較小時(shí),(2)和(3)的性能差距是比較大的,這說(shuō)明如果訓(xùn)練的NBS比較小時(shí)在測(cè)試時(shí)采用mini-batch統(tǒng)計(jì)量效果會(huì)更好,此時(shí)保持一致是比較重要的(這點(diǎn)至關(guān)重要)。當(dāng)NBS較大時(shí),(2)和(3)的差異就比較小,此時(shí)mini-batch統(tǒng)計(jì)量越來(lái)越接近全局統(tǒng)計(jì)量。

          雖然NBS較小時(shí),在測(cè)試時(shí)采用mini-batch統(tǒng)計(jì)量效果更好,但是在實(shí)際場(chǎng)景中幾乎不會(huì)這樣處理(一般情況下都是每次處理一個(gè)樣本)。不過(guò)還是有一些特例,比如兩階段檢測(cè)模型R-CNN中,R-CNN的head輸入是每個(gè)圖像的一系列region-of-interest (RoIs),一般情況下一個(gè)圖像會(huì)有個(gè)RoIs,實(shí)際情況這些RoIs是組成batch進(jìn)行處理的,訓(xùn)練過(guò)程是所有圖像的RoIs,而測(cè)試時(shí)是單個(gè)圖像的RoIs組成batch,在這種情況中測(cè)試時(shí)就可以選擇mini-batch統(tǒng)計(jì)量。這里以Mask R-CNN為實(shí)驗(yàn)?zāi)P停瑢⒛J(rèn)的2fc box head(2個(gè)全連接層)換成4conv1fc head(4個(gè)卷積層加一個(gè),并且在box head和mask head的每個(gè)卷積層后面都加上BatchNorm層,實(shí)驗(yàn)結(jié)果如下,可以看到采用mini-batch統(tǒng)計(jì)量是優(yōu)于采用全局統(tǒng)計(jì)量的,另外在訓(xùn)練過(guò)程中每個(gè)GPU只用一張圖像時(shí),此時(shí)測(cè)試時(shí)采用全局統(tǒng)計(jì)量效果會(huì)很差,這里有另外的過(guò)擬合問(wèn)題存在,后面再述(BatchNorm導(dǎo)致的信息泄露)。另外R-CNN的head還存在另外的一種訓(xùn)練和測(cè)試的inconsistency:訓(xùn)練時(shí)mini-batch是正負(fù)樣本抽樣的,而測(cè)試時(shí)是按score選取的topK,mini-batch的分布就發(fā)生了變化。


          另外一個(gè)避免訓(xùn)練和測(cè)試的inconsistency可選方案是訓(xùn)練也采用全局統(tǒng)計(jì)量,常用的方案是Frozen BatchNorm (FrozenBN)(訓(xùn)練中直接采用EMA統(tǒng)計(jì)量模型無(wú)法訓(xùn)練),F(xiàn)rozenBN指的是采用一個(gè)提前算好的固定全局統(tǒng)計(jì)量,此時(shí)BatchNorm的訓(xùn)練優(yōu)化就只有一個(gè)linear transform了。FrozenBN采用的情景是將一個(gè)已經(jīng)訓(xùn)練好的模型遷移到其它任務(wù),如在ImageNet訓(xùn)練的ResNet模型在遷移到下游檢測(cè)任務(wù)時(shí)一般采用FrozenBN。不過(guò)我們也可以在模型的訓(xùn)練過(guò)程中采用FrozenBN,論文中還是以ResNet50為例,在前80個(gè)epoch采用正常的BN訓(xùn)練,在后20個(gè)epoch采用FrozenBN,對(duì)比效果如下,可以看到FrozenBN在NBS較小時(shí)也是表現(xiàn)較好,優(yōu)于測(cè)試時(shí)采用mini-batch統(tǒng)計(jì)量,這不失為一種好的方案。這里值得注意的是當(dāng)NBS較大時(shí),F(xiàn)rozenBN還是測(cè)試時(shí)采用mini-batch統(tǒng)計(jì)量均差于常規(guī)方案(BN訓(xùn)練,測(cè)試時(shí)采用全局統(tǒng)計(jì)量)。



          推薦閱讀

          CPVT:一個(gè)卷積就可以隱式編碼位置信息

          SOTA模型Swin Transformer是如何煉成的!

          谷歌AI用30億數(shù)據(jù)訓(xùn)練了一個(gè)20億參數(shù)Vision Transformer模型,在ImageNet上達(dá)到新的SOTA!

          DETR:基于 Transformers 的目標(biāo)檢測(cè)

          目標(biāo)跟蹤入門(mén)篇-相關(guān)濾波

          SOTA模型Swin Transformer是如何煉成的!

          MoCo V3:我并不是你想的那樣!

          Transformer在語(yǔ)義分割上的應(yīng)用

          "未來(lái)"的經(jīng)典之作ViT:transformer is all you need!

          PVT:可用于密集任務(wù)backbone的金字塔視覺(jué)transformer!

          漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA

          Transformer為何能闖入CV界秒殺CNN?

          不妨試試MoCo,來(lái)替換ImageNet上pretrain模型!


          機(jī)器學(xué)習(xí)算法工程師


                                              一個(gè)用心的公眾號(hào)


          瀏覽 117
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  国产深夜福利 | 操出水视频在线观看网站国产 | 国产精品一久久 | 成人国产一区二区三区精品麻豆 | 白丝暴肛在线观看91 |