<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          CVPR2021-Representative BatchNorm

          共 10889字,需瀏覽 22分鐘

           ·

          2021-04-14 00:08

          源代碼地址:https://github.com/ShangHua-Gao/RBN論文地址: Representative Batch Normalization with Feature Calibration

          引言

          BatchNorm模塊能讓模型訓(xùn)練更加穩(wěn)定,因而被廣泛使用。它的中心化以及縮放步驟需要依賴樣本統(tǒng)計(jì)得到的均值和方差,而這也導(dǎo)致了在歸一化的過程,忽視了各個(gè)實(shí)例的區(qū)別。其中,中心化步驟是為了增強(qiáng)信息特征,減少噪聲。而縮放步驟是為了讓特征服從一個(gè)穩(wěn)定的分布。考慮到不同實(shí)例有不同特點(diǎn),我們引入了簡(jiǎn)單有效的特征校準(zhǔn)步驟(feature calibration scheme),改進(jìn)得到Representative BatchNorm,在各大圖像任務(wù)均有一定的提升。

          BN的缺點(diǎn)

          BatchNorm公式如下,它將特征縮放為一個(gè)均值為0,方差為1的分布

          BatchNorm的一個(gè)前提是,我們假定了不同實(shí)例對(duì)應(yīng)的特征都服從相同的分布。但實(shí)際中,存在以下兩種情況不滿足上述的假設(shè):

          1. 一個(gè)mini-batch里的統(tǒng)計(jì)信息(均值,方差)與總的訓(xùn)練集/測(cè)試集的統(tǒng)計(jì)信息不一致
          2. 測(cè)試集中的數(shù)據(jù)實(shí)例不符合訓(xùn)練集的分布

          針對(duì)第一點(diǎn),BatchNorm在batchsize比較小的情況下,統(tǒng)計(jì)得到的均值和方差不夠準(zhǔn)確,相比其他Normalize方法(如GroupNorm)表現(xiàn)的很差。

          而針對(duì)第二點(diǎn),因?yàn)樵谕评磉^程中使用的是訓(xùn)練過程中統(tǒng)計(jì)更新的running-meanrunning-variance。若測(cè)試集不與訓(xùn)練集在一個(gè)分布下,在BN后,它不一定服從的是均值為0,方差為1的分布。

          針對(duì)不同情況,對(duì)模型的影響也不同

          • 當(dāng)測(cè)試集的均值小于running-mean,BN會(huì)錯(cuò)誤地移除掉具有代表性的特征
          • 當(dāng)測(cè)試集的均值大于running-mean,BN會(huì)“漏掉”特征中的噪聲
          • 當(dāng)測(cè)試集的方差小于running-var,BN會(huì)導(dǎo)致特征的intensity過小
          • 當(dāng)測(cè)試集的方差大于running-var,BN會(huì)導(dǎo)致特征的intensity過大

          個(gè)人理解這里的intensity指的是特征強(qiáng)度,可能比較抽象,一方面指的是特征值的范圍,另一方面也可以指特征的變化劇烈強(qiáng)度

          為了解決上述的問題,一個(gè)很自然的想法是怎么將各個(gè)數(shù)據(jù)實(shí)例的特征,與mini-batch統(tǒng)計(jì)信息很好的結(jié)合在一起。一方面也能讓特征處在穩(wěn)定的分布,另一方面也能根據(jù)各個(gè)實(shí)例的特點(diǎn)進(jìn)行進(jìn)一步調(diào)整

          Representative Batch Normalization

          為了解決上述問題,我們提出了RBN,其中RBN也分為兩個(gè)步驟,一個(gè)是中心化校準(zhǔn)(Centering Calibration),一個(gè)是縮放校準(zhǔn)(Scaling Calibration)

          Centering Calibration

          我們先看下公式

          在對(duì)X求均值的時(shí)候,我們先對(duì)其做一個(gè)變換

          其中 是輸入特征, 則是一個(gè)形狀為(N, C, 1, 1)的可學(xué)習(xí)變量, 則是表示各個(gè)實(shí)例的特征,它可以有多種shape(只要是合理的變換,能表征各個(gè)實(shí)例的特征即可),這里我們對(duì)輸入使用一個(gè)全局平均池化來得到實(shí)例特征,因此形狀為(N, C, 1, 1)。

          我們首先將實(shí)例特征與可學(xué)習(xí)變量相乘,最后與輸入進(jìn)行相加

          公式推導(dǎo)

          對(duì)于使用全局平均池化得到實(shí)例特征,我們有如下的公式

          因?yàn)楹罄m(xù)我們要對(duì)變換后的X求均值(在BN里是對(duì)N,H,W這三個(gè)維度求均值),對(duì)于 來說,已經(jīng)是X對(duì)HW維度上求過均值了,后續(xù)不過是在N的維度上再求一次均值。所以我們有

          我們針對(duì)變換后的X求均值,有

          然后我們來對(duì)比一下該變換帶來的差異

          我們將兩個(gè)進(jìn)行相減比較差異

          可以看到,當(dāng) 的絕對(duì)值接近于0,的差值接近于0,說明此時(shí)還是依賴于batch內(nèi)的統(tǒng)計(jì)信息。當(dāng) 的絕對(duì)值較大,具體可以分以下兩種情況來考慮

          • 當(dāng) 大于0,且 > ,此時(shí)Representative Feature得到增強(qiáng),反之亦然
          • 當(dāng) 小于0,且 > ,此時(shí)特征噪聲會(huì)抑制,反之亦然

          Scaling Calibration

          我們?cè)?strong style="font-weight: bold;color: black;">BN后,拉伸調(diào)整之前做一次縮放對(duì)齊

          公式如下:

          其中 是兩個(gè)可學(xué)習(xí)參數(shù),用于拉伸平移(跟BN的兩個(gè)可學(xué)習(xí)參數(shù)效果類似) 跟前面的類似,是一個(gè)實(shí)例特征,這里還是用全局平均池化得到。 則是一個(gè)限制函數(shù),可以使用各種范數(shù)來限制,這里采用的是 sigmoid 函數(shù)來限制值域

          公式推導(dǎo)

          我們的限制函數(shù)是 sigmoid,于是有

          那么我們可以找到一個(gè)  滿足

          可以看到我們的方差因?yàn)橄拗坪瘮?shù)而變得更小了,讓分布更加的均勻

          各通道均值的標(biāo)準(zhǔn)差比較

          整體流程

          首先對(duì)輸入做中心校準(zhǔn)

          然后就是熟悉的減均值,除方差

          接著是做縮放校準(zhǔn)

          最后是做拉伸,偏移,得到最終結(jié)果

          實(shí)驗(yàn)對(duì)比

          作者在主流的網(wǎng)絡(luò)里測(cè)試了常見的Normalize模塊,并進(jìn)行對(duì)比,可以看到提升還是比較顯著的另外也通過消融實(shí)驗(yàn)證明均值校準(zhǔn)和縮放校準(zhǔn)的有效性,另外更多實(shí)驗(yàn)可以看下原文。

          代碼

          作者也開放了對(duì)應(yīng)的Pytorch源碼

          import torch.nn as nn
          import math
          import torch
          import numpy as np
          import torch.nn.functional as F
          class RepresentativeBatchNorm2d(nn.BatchNorm2d):
              def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                           track_running_stats=True)
          :

                  super(RepresentativeBatchNorm2d, self).__init__(
                      num_features, eps, momentum, affine, track_running_stats)
                  self.num_features = num_features
                  ### weights for affine transformation in BatchNorm ###
                  if self.affine:
                      self.weight = nn.Parameter(torch.Tensor(1, num_features, 11))
                      self.bias = nn.Parameter(torch.Tensor(1, num_features, 11))
                      self.weight.data.fill_(1)
                      self.bias.data.fill_(0)
                  else:
                      self.register_parameter('weight'None)
                      self.register_parameter('bias'None)

                  ### weights for centering calibration ###        
                  self.center_weight = nn.Parameter(torch.Tensor(1, num_features, 11))
                  self.center_weight.data.fill_(0)
                  ### weights for scaling calibration ###            
                  self.scale_weight = nn.Parameter(torch.Tensor(1, num_features, 11))
                  self.scale_bias = nn.Parameter(torch.Tensor(1, num_features, 11))
                  self.scale_weight.data.fill_(0)
                  self.scale_bias.data.fill_(1)
                  ### calculate statistics ###
                  self.stas = nn.AdaptiveAvgPool2d((1,1))

              def forward(self, input):
                  self._check_input_dim(input)

                  ####### centering calibration begin ####### 
                  input += self.center_weight.view(1,self.num_features,1,1)*self.stas(input)
                  ####### centering calibration end ####### 

                  ####### BatchNorm begin #######         
                  if self.momentum is None:
                      exponential_average_factor = 0.0
                  else:
                      exponential_average_factor = self.momentum

                  if self.training and self.track_running_stats:
                      if self.num_batches_tracked is not None:
                          self.num_batches_tracked = self.num_batches_tracked + 1
                          if self.momentum is None:  # use cumulative moving average
                              exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                          else
                              exponential_average_factor = self.momentum
                  output = F.batch_norm(
                      input, self.running_mean, self.running_var, NoneNone,
                      self.training or not self.track_running_stats,
                      exponential_average_factor, self.eps)
                  ####### BatchNorm end #######

                  ####### scaling calibration begin ####### 
                  scale_factor = torch.sigmoid(self.scale_weight*self.stas(output)+self.scale_bias)
                  ####### scaling calibration end ####### 
                  if self.affine:
                      return self.weight*scale_factor*output + self.bias
                  else:
                      return scale_factor*output

          其中大部分代碼跟Pytorch自己實(shí)現(xiàn)的BatchNorm類似,我們簡(jiǎn)單關(guān)注幾點(diǎn)

          首先在初始化里,初始化了中心校準(zhǔn),縮放校準(zhǔn)所需的可學(xué)習(xí)參數(shù),并填充默認(rèn)值

          ### weights for centering calibration ###        
          self.center_weight = nn.Parameter(torch.Tensor(1, num_features, 11))
          self.center_weight.data.fill_(0)
          ### weights for scaling calibration ###            
          self.scale_weight = nn.Parameter(torch.Tensor(1, num_features, 11))
          self.scale_bias = nn.Parameter(torch.Tensor(1, num_features, 11))
          self.scale_weight.data.fill_(0)
          self.scale_bias.data.fill_(1)

          我們經(jīng)常會(huì)把可學(xué)習(xí)參數(shù)中,權(quán)重w初始化為1,偏置b初始化為0,而這里恰恰相反,將權(quán)重則初始化為0,偏置則為1。個(gè)人推測(cè)可以參考推導(dǎo)Centering Calibration中,當(dāng)w為0時(shí),則等價(jià)于原始的BN,從而后續(xù)讓模型根據(jù)需要來去調(diào)整w。但為什么偏置設(shè)為1,筆者沒想清楚。可以參考RBN開源工程的issue1,地址在這篇文章開頭

          然后是初始化我們的實(shí)例特征提取操作,這里是用一個(gè)全局池化

          ### calculate statistics ###
          self.stas = nn.AdaptiveAvgPool2d((1,1))

          在forward函數(shù)一開始,我們先做中心校準(zhǔn)操作

          ####### centering calibration begin ####### 
          input += self.center_weight.view(1,self.num_features,1,1)*self.stas(input)
          ####### centering calibration end ####### 

          然后是調(diào)用torch自帶的Batchnorm

          ...
          output = F.batch_norm(
                      input, self.running_mean, self.running_var, NoneNone,
                      self.training or not self.track_running_stats,
                      exponential_average_factor, self.eps)

          接著做縮放校準(zhǔn)操作

          ####### scaling calibration begin ####### 
          scale_factor = torch.sigmoid(self.scale_weight*self.stas(output)+self.scale_bias)
          ####### scaling calibration end #######

          最后根據(jù)屬性 self.affine 做最后的拉伸和偏移

          if self.affine:
              return self.weight*scale_factor*output + self.bias
          else:
              return scale_factor*output

          總結(jié)

          作者提出了一種簡(jiǎn)單有效的方法,將BN層的mini-batch的統(tǒng)計(jì)特征和各個(gè)實(shí)例獨(dú)自的特征(Representative也就體現(xiàn)在這里)巧妙的結(jié)合起來,使得能夠更好自適應(yīng)集合里的數(shù)據(jù),最后各個(gè)實(shí)驗(yàn)也證明了其有效性。期待更多在Norm方面的工作~


          歡迎關(guān)注GiantPandaCV, 在這里你將看到獨(dú)家的深度學(xué)習(xí)分享,堅(jiān)持原創(chuàng),每天分享我們學(xué)習(xí)到的新鮮知識(shí)。( ? ?ω?? )?

          有對(duì)文章相關(guān)的問題,或者想要加入交流群,歡迎添加BBuf微信:

          二維碼


          瀏覽 103
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  亚洲精品免费观看 | 无码一区二区黑人猛烈视频网站 | 人人摸人人草人人 | 黄色视频网站亚洲 | 国产精品夜夜爽7777777 |