<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          BatchNorm的避坑指南(下)

          共 10114字,需瀏覽 21分鐘

           ·

          2021-06-24 11:24

          點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師

          設(shè)為星標(biāo),干貨直達(dá)!

          Batch from Different Domains

          包含BatchNorm的模型訓(xùn)練過(guò)程包含兩個(gè)學(xué)習(xí)過(guò)程:一是模型主體參數(shù)是通過(guò)SGD學(xué)習(xí)得到的(SGD training),二是全局統(tǒng)計(jì)量是通過(guò)EMA或者PreciseBN從訓(xùn)練數(shù)據(jù)中學(xué)習(xí)得到(population statistics training)。當(dāng)訓(xùn)練數(shù)據(jù)和測(cè)試數(shù)據(jù)分布不同時(shí),我們稱之為domain shift,這個(gè)時(shí)候?qū)W習(xí)得到的全局統(tǒng)計(jì)量就可能會(huì)在測(cè)試時(shí)失效,這個(gè)問(wèn)題已經(jīng)有論文提出要采用Adaptive BatchNorm來(lái)解決,即在測(cè)試數(shù)據(jù)上重新計(jì)算全局統(tǒng)計(jì)量。這里還是以ResNet50為例(SGD batch size為1024,NBS為32),用ImageNet-C數(shù)據(jù)集(ImageNet的擾動(dòng)版本,共三種類型:contrast,gaussian noise和jpeg compression)來(lái)評(píng)估domain shift問(wèn)題,結(jié)果如下:


          從表中可以明顯看出:當(dāng)出現(xiàn)domain shift問(wèn)題后,采用Adaptive BatchNorm在target domain數(shù)據(jù)集上重新計(jì)算全局統(tǒng)計(jì)量可以提升模型效果。不過(guò)從表最后一行可以看到,如果在ImageNet驗(yàn)證集上重新計(jì)算統(tǒng)計(jì)量(直接采用inference-time預(yù)處理),最終效果要稍微差于原來(lái)結(jié)果(23.4 VS 23.8),這可能說(shuō)明如果不存在明顯的domain shift,原始處理方式是最好的。

          除了domain shift,訓(xùn)練數(shù)據(jù)存在multi-domain也會(huì)對(duì)BatchNorm產(chǎn)生影響,這個(gè)問(wèn)題更復(fù)雜了。這里以RetinaNet模型來(lái)說(shuō)明multi-domain的出現(xiàn)可能出現(xiàn)的問(wèn)題。RetinaNet的head包含4個(gè)卷積層以及最終的分類器和回歸器,其輸入是來(lái)自不同尺度的5個(gè)特征(),這可以kan'chehead在5個(gè)特征上是共享的,默認(rèn)head是不包含BatchNorm,當(dāng)我們?cè)诿總€(gè)卷積后加上BatchNorm后,問(wèn)題就變得復(fù)雜了。首先,首先就是訓(xùn)練過(guò)程mini-batch統(tǒng)計(jì)量的計(jì)算,明顯有兩種不同處理方式,一是對(duì)不同domain的特征輸入單獨(dú)計(jì)算mini-batch統(tǒng)計(jì)量并單獨(dú)歸一化,二是將所有domain的特征concat在一起,計(jì)算一個(gè)mini-batch統(tǒng)計(jì)量來(lái)歸一化。這兩種處理方式如下所示:


          這里記SGD訓(xùn)練過(guò)程中的兩種方式分別為domain-specific statisticsshared statistics。對(duì)于學(xué)習(xí)全局統(tǒng)計(jì)量,同樣存在對(duì)應(yīng)的兩種方式,即每個(gè)domain的特征單獨(dú)學(xué)習(xí)一套全局統(tǒng)計(jì)量,還是共享一套全局統(tǒng)計(jì)量。對(duì)于BatchNorm的affine transform layer也存在兩種選擇:每個(gè)domain一套參數(shù)還是共享參數(shù)。不同組合的模型效果如下表所示:


          從表中結(jié)果可以總結(jié)兩個(gè)結(jié)論:(1)SGD training和population statistics training保持一致非常重要,此時(shí)都可以取得較好的結(jié)果(行1,行4和行6);(2)affine transform layer無(wú)論單獨(dú)參數(shù)還是共享基本不影響結(jié)果。這里的一個(gè)小插曲是如果直接在head中加上BatchNorm,其實(shí)對(duì)應(yīng)的是行3,其實(shí)這是因?yàn)椴煌叨鹊奶卣魇切蛄刑幚淼模@就造成了SGD training其實(shí)是domain-specific的,此時(shí)效果就較差,所以大部分實(shí)現(xiàn)中要不然不用norm,要不然就用BatchNorm。不同組合的實(shí)現(xiàn)代碼如下:

          # 簡(jiǎn)單地加上BN,注意forward時(shí),不同特征是串行處理的,那么SGD training其實(shí)是domain-specific的,
          # 但是只維持一套全局統(tǒng)計(jì)量,所以測(cè)試時(shí)又是共享的
          class RetinaNetHead_Row3:
              def __init__(self, num_conv, channel):
                  head = []
                  for _ in range(num_conv):
                      head.append(nn.Conv2d(channel, channel, 3))
                      head.append(nn.BatchNorm2d(channel))
                  self.head = nn.Sequential(?head)
              def forward(self, inputs: List[Tensor]):
                  return [self.head(i) for i in inputs]

          # 如果要共享,那么在forward時(shí)對(duì)特征進(jìn)行concat來(lái)統(tǒng)一計(jì)算并歸一化 
          class RetinaNetHead_Row1(RetinaNetHead_Row3):
              def forward(self, inputs: List[Tensor]):
                  for mod in self.head:
                      if isinstance(mod, nn.BatchNorm2d):
                          # for BN layer, normalize all inputs together
                          shapes = [i.shape for i in inputs]
                          spatial_sizes = [s[2] ? s[3for s in shapes]
                          x = [i.flatten(2for i in inputs]
                          x = torch.cat(x, dim=2).unsqueeze(3)
                          x = mod(x).split(spatial_sizes, dim=2)
                          inputs = [i.view(s) for s, i in zip(shapes, x)]
                      else:
                          # for conv layer, apply it separately
                          inputs = [mod(i) for i in inputs]
                  return inputs

          # 另外一種簡(jiǎn)單的處理方式是每個(gè)特征采用各自的BN
          class RetinaNetHead_Row6:
              def __init__(self, num_conv, channel, num_features):
                  # num_features: number of features coming from
                  # different FPN levels, e.g. 5
                  heads = [[] for _ in range(num_levels)]
                  for _ in range(num_conv):
                      conv = nn.Conv2d(channel, channel, 3)
                      for h in heads:
                          # add a shared conv and a domain?specific BN
                          h.extend([conv, nn.BatchNorm2d(channel)])
                  self.heads = [nn.Sequential(?h) for h in heads]
              def forward(self, inputs: List[Tensor]):
                  # end up with one head for each input
                  return [head(i) for head, i in
                      zip(self.heads, inputs)]

          對(duì)于行2和行4,可以通過(guò)訓(xùn)練好的行1和行3模型重新在訓(xùn)練數(shù)據(jù)上計(jì)算domain-specific全局統(tǒng)計(jì)量即可,在實(shí)現(xiàn)時(shí),可以如下:

          class CycleBatchNormList(nn.ModuleList):
              """
              A hacky way to implement domain-specific BatchNorm
              if it's guaranteed that a fixed number of domains will be
              called with fixed order.
              """


              def __init__(self, length, channels):
                  super().__init__([nn.BatchNorm2d(channels, affine=Falsefor k in range(length)])
                  # shared affine, domain-specific BN
                  self.weight = nn.Parameter(torch.ones(channels))
                  self.bias = nn.Parameter(torch.zeros(channels))
                  self._pos = 0

              def forward(self, x):
                  ret = self[self._pos](x)
                  self._pos = (self._pos + 1) % len(self)

                  w = self.weight.reshape(1-111)
                  b = self.bias.reshape(1-111)
                  return ret * w + b

          # 訓(xùn)練好模型,我們可以重新將BN層換成以上實(shí)現(xiàn),就可以在訓(xùn)練數(shù)據(jù)上重新計(jì)算domain-specific全局統(tǒng)計(jì)量

          RetinaNet面臨的其實(shí)是特征層面的multi-domain問(wèn)題,而且每個(gè)batch中的各個(gè)domain是均勻的。如果是數(shù)據(jù)層面的multi-domain,其面臨的問(wèn)題將會(huì)復(fù)雜,此時(shí)domain的分布比例也是多變的,但是總的原則是盡量減少不一致性,因?yàn)?strong style="line-height: 1.6 !important;">consistency is crucial。

          Information Leakage within a Batch

          BatchNorm面臨的另外一個(gè)挑戰(zhàn),就是可能出現(xiàn)信息泄露,這里所說(shuō)的信息泄露指的是模型學(xué)習(xí)到了利用mini-batch的信息來(lái)做預(yù)測(cè),而這些其實(shí)并不是我們要學(xué)習(xí)的,因?yàn)檫@樣模型可能難以對(duì)mini-batch里的每個(gè)sample單獨(dú)做預(yù)測(cè)。


          比如BatchNorm的作者曾做過(guò)這樣一個(gè)實(shí)驗(yàn),在ResNet50的訓(xùn)練過(guò)程中,NBS=32,但是保證里面包含16個(gè)類別,每個(gè)類別有2個(gè)圖像,這樣一種特殊的設(shè)計(jì)要模型在訓(xùn)練過(guò)程中強(qiáng)制記憶了這種模式(可能是每個(gè)mini-batch中必須有同類別存在),那么在測(cè)試時(shí)如果輸入不是這種設(shè)計(jì),效果就會(huì)變差。這個(gè)在驗(yàn)證集上不同處理結(jié)果如上所示,可以看到測(cè)試時(shí)無(wú)論是采用全局統(tǒng)計(jì)量還是random mini-batch統(tǒng)計(jì)量,效果均較差,但是如果采用和訓(xùn)練過(guò)程同樣的模式,效果就比較好。這其實(shí)也從側(cè)面說(shuō)明保持一致是多么的重要。

          前面說(shuō)過(guò),如果在R-CNN的head加入BatchNorm,那么在測(cè)試時(shí)采用mini-batch統(tǒng)計(jì)量會(huì)比全局統(tǒng)計(jì)量會(huì)效果更好,這里面其實(shí)也存在信息泄露的問(wèn)題。對(duì)于每個(gè)GPU只有一個(gè)image的情況,每個(gè)mini-batch里面的RoIs全部來(lái)自于一個(gè)圖像,這時(shí)候模型就可能依賴mini-batch來(lái)做預(yù)測(cè),那么測(cè)試時(shí)采用全局統(tǒng)計(jì)量效果就會(huì)差了,對(duì)于每個(gè)GPU有多個(gè)圖像時(shí),情況還稍好一些,所以原來(lái)的結(jié)果中單卡單圖像效果最差。一種解決方案是采用shuffle BN,就是head進(jìn)行處理前,先隨機(jī)打亂所有卡上的RoIs特征,每個(gè)卡分配隨機(jī)的RoIs,這樣就避免前面那個(gè)可能出現(xiàn)的信息泄露,head處理完后再shuffle回來(lái),具體處理流程如下所示:


          這個(gè)具體的代碼實(shí)現(xiàn)見mask_rcnn_BNhead_shuffle.py。其實(shí)在MoCo中也使用了shuffle BN來(lái)防止信息泄露。另外還是可以采用SyncBN來(lái)避免這種問(wèn)題。具體的對(duì)比結(jié)果如下所示,可以看到采用shuffle BN和SyncBN均可以避免信息泄露,得到較好的效果。注意shuffle BN的 cross-GPU synchronization要比SyncBN要少,效率更高一些。


          另外一個(gè)常見的場(chǎng)景是對(duì)比學(xué)習(xí)中信息泄露,因?yàn)閷?duì)比學(xué)習(xí)往往需要對(duì)同一個(gè)圖像做不同的augmentations來(lái)作為正樣本,這其實(shí)一個(gè)sample既當(dāng)輸入又當(dāng)目標(biāo),mini-batch可能會(huì)泄露信息導(dǎo)致模型學(xué)習(xí)不到好的特征。MoCo采用的是shuffle BN,而SimCLR和BYOL采用的是SyncBN。另外曠視提出的Momentum^2 Teacher來(lái)采用moving average statistics來(lái)防止信息泄露。

          總結(jié)

          一個(gè)簡(jiǎn)單的BatchNorm,如果我們使用不當(dāng),往往會(huì)出現(xiàn)一些讓人意料的結(jié)果,所以要謹(jǐn)慎處理。總結(jié)來(lái)看,主要有如下結(jié)論和指南:

          • 模型在未收斂時(shí)使用EMA統(tǒng)計(jì)量來(lái)評(píng)估模型是不穩(wěn)定的,一種替代方案是PreciseBN;
          • BatchNorm本身存在訓(xùn)練和測(cè)試的不一致性,特別是NBS較少時(shí),這種不一致會(huì)更強(qiáng),可用的方案是測(cè)試時(shí)也采用mini-batch統(tǒng)計(jì)量或者采用FrozenBN;
          • 在domain shift場(chǎng)景中,最好基于target domain數(shù)據(jù)重新計(jì)算全局統(tǒng)計(jì)量,在multi-domain數(shù)據(jù)訓(xùn)練時(shí),要特別注意處理的一致性;
          • BatchNorm會(huì)存在信息泄露的風(fēng)險(xiǎn),這處理mini-batch時(shí)要防止特殊的出現(xiàn)。

          參考

          • Rethinking "Batch" in BatchNorm

          • detectron2/projects/Rethinking-BatchNorm

          推薦閱讀

          CPVT:一個(gè)卷積就可以隱式編碼位置信息

          SOTA模型Swin Transformer是如何煉成的!

          谷歌AI用30億數(shù)據(jù)訓(xùn)練了一個(gè)20億參數(shù)Vision Transformer模型,在ImageNet上達(dá)到新的SOTA!

          DETR:基于 Transformers 的目標(biāo)檢測(cè)

          目標(biāo)跟蹤入門篇-相關(guān)濾波

          SOTA模型Swin Transformer是如何煉成的!

          MoCo V3:我并不是你想的那樣!

          Transformer在語(yǔ)義分割上的應(yīng)用

          "未來(lái)"的經(jīng)典之作ViT:transformer is all you need!

          PVT:可用于密集任務(wù)backbone的金字塔視覺transformer!

          漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA

          Transformer為何能闖入CV界秒殺CNN?

          不妨試試MoCo,來(lái)替換ImageNet上pretrain模型!


          機(jī)器學(xué)習(xí)算法工程師


                                              一個(gè)用心的公眾號(hào)


          瀏覽 162
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  免费精品久久久久久中文字幕-无删减 | 波多野结衣中文字幕久久 | 欧美成人三级精品 | 2024AV天堂 | 91久久精品人人槡人妻人 |