CVPR2021-Representative BatchNorm
源代碼地址:https://github.com/ShangHua-Gao/RBN論文地址: Representative Batch Normalization with Feature Calibration
引言
BatchNorm模塊能讓模型訓(xùn)練更加穩(wěn)定,因而被廣泛使用。它的中心化以及縮放步驟需要依賴樣本統(tǒng)計(jì)得到的均值和方差,而這也導(dǎo)致了在歸一化的過程,忽視了各個(gè)實(shí)例的區(qū)別。其中,中心化步驟是為了增強(qiáng)信息特征,減少噪聲。而縮放步驟是為了讓特征服從一個(gè)穩(wěn)定的分布。考慮到不同實(shí)例有不同特點(diǎn),我們引入了簡(jiǎn)單有效的特征校準(zhǔn)步驟(feature calibration scheme),改進(jìn)得到Representative BatchNorm,在各大圖像任務(wù)均有一定的提升。
BN的缺點(diǎn)
BatchNorm公式如下,它將特征縮放為一個(gè)均值為0,方差為1的分布
BatchNorm的一個(gè)前提是,我們假定了不同實(shí)例對(duì)應(yīng)的特征都服從相同的分布。但實(shí)際中,存在以下兩種情況不滿足上述的假設(shè):
一個(gè)mini-batch里的統(tǒng)計(jì)信息(均值,方差)與總的訓(xùn)練集/測(cè)試集的統(tǒng)計(jì)信息不一致 測(cè)試集中的數(shù)據(jù)實(shí)例不符合訓(xùn)練集的分布
針對(duì)第一點(diǎn),BatchNorm在batchsize比較小的情況下,統(tǒng)計(jì)得到的均值和方差不夠準(zhǔn)確,相比其他Normalize方法(如GroupNorm)表現(xiàn)的很差。
而針對(duì)第二點(diǎn),因?yàn)樵谕评磉^程中使用的是訓(xùn)練過程中統(tǒng)計(jì)更新的running-mean和running-variance。若測(cè)試集不與訓(xùn)練集在一個(gè)分布下,在BN后,它不一定服從的是均值為0,方差為1的分布。
針對(duì)不同情況,對(duì)模型的影響也不同
當(dāng)測(cè)試集的均值小于running-mean,BN會(huì)錯(cuò)誤地移除掉具有代表性的特征 當(dāng)測(cè)試集的均值大于running-mean,BN會(huì)“漏掉”特征中的噪聲 當(dāng)測(cè)試集的方差小于running-var,BN會(huì)導(dǎo)致特征的intensity過小 當(dāng)測(cè)試集的方差大于running-var,BN會(huì)導(dǎo)致特征的intensity過大
個(gè)人理解這里的intensity指的是特征強(qiáng)度,可能比較抽象,一方面指的是特征值的范圍,另一方面也可以指特征的變化劇烈強(qiáng)度
為了解決上述的問題,一個(gè)很自然的想法是怎么將各個(gè)數(shù)據(jù)實(shí)例的特征,與mini-batch統(tǒng)計(jì)信息很好的結(jié)合在一起。一方面也能讓特征處在穩(wěn)定的分布,另一方面也能根據(jù)各個(gè)實(shí)例的特點(diǎn)進(jìn)行進(jìn)一步調(diào)整
Representative Batch Normalization
為了解決上述問題,我們提出了RBN,其中RBN也分為兩個(gè)步驟,一個(gè)是中心化校準(zhǔn)(Centering Calibration),一個(gè)是縮放校準(zhǔn)(Scaling Calibration)
Centering Calibration
我們先看下公式
在對(duì)X求均值的時(shí)候,我們先對(duì)其做一個(gè)變換
其中 是輸入特征, 則是一個(gè)形狀為(N, C, 1, 1)的可學(xué)習(xí)變量, 則是表示各個(gè)實(shí)例的特征,它可以有多種shape(只要是合理的變換,能表征各個(gè)實(shí)例的特征即可),這里我們對(duì)輸入使用一個(gè)全局平均池化來得到實(shí)例特征,因此形狀為(N, C, 1, 1)。
我們首先將實(shí)例特征與可學(xué)習(xí)變量相乘,最后與輸入進(jìn)行相加
公式推導(dǎo)
對(duì)于使用全局平均池化得到實(shí)例特征,我們有如下的公式
因?yàn)楹罄m(xù)我們要對(duì)變換后的X求均值(在BN里是對(duì)N,H,W這三個(gè)維度求均值),對(duì)于 來說,已經(jīng)是X對(duì)HW維度上求過均值了,后續(xù)不過是在N的維度上再求一次均值。所以我們有
我們針對(duì)變換后的X求均值,有
然后我們來對(duì)比一下該變換帶來的差異
我們將兩個(gè)進(jìn)行相減比較差異
可以看到,當(dāng) 的絕對(duì)值接近于0, 和 的差值接近于0,說明此時(shí)還是依賴于batch內(nèi)的統(tǒng)計(jì)信息。當(dāng) 的絕對(duì)值較大,具體可以分以下兩種情況來考慮
當(dāng) 大于0,且 > ,此時(shí)Representative Feature得到增強(qiáng),反之亦然 當(dāng) 小于0,且 > ,此時(shí)特征噪聲會(huì)抑制,反之亦然
Scaling Calibration
我們?cè)?strong style="font-weight: bold;color: black;">BN后,拉伸調(diào)整之前做一次縮放對(duì)齊
公式如下:
其中 和 是兩個(gè)可學(xué)習(xí)參數(shù),用于拉伸平移(跟BN的兩個(gè)可學(xué)習(xí)參數(shù)效果類似) 跟前面的類似,是一個(gè)實(shí)例特征,這里還是用全局平均池化得到。 則是一個(gè)限制函數(shù),可以使用各種范數(shù)來限制,這里采用的是 sigmoid 函數(shù)來限制值域
公式推導(dǎo)
我們的限制函數(shù)是 sigmoid,于是有
那么我們可以找到一個(gè) 滿足
可以看到我們的方差因?yàn)橄拗坪瘮?shù)而變得更小了,讓分布更加的均勻

整體流程
首先對(duì)輸入做中心校準(zhǔn)
然后就是熟悉的減均值,除方差
接著是做縮放校準(zhǔn)
最后是做拉伸,偏移,得到最終結(jié)果
實(shí)驗(yàn)對(duì)比
作者在主流的網(wǎng)絡(luò)里測(cè)試了常見的Normalize模塊,并進(jìn)行對(duì)比,可以看到提升還是比較顯著的
另外也通過消融實(shí)驗(yàn)證明均值校準(zhǔn)和縮放校準(zhǔn)的有效性,另外更多實(shí)驗(yàn)可以看下原文。
代碼
作者也開放了對(duì)應(yīng)的Pytorch源碼
import torch.nn as nn
import math
import torch
import numpy as np
import torch.nn.functional as F
class RepresentativeBatchNorm2d(nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True):
super(RepresentativeBatchNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats)
self.num_features = num_features
### weights for affine transformation in BatchNorm ###
if self.affine:
self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.weight.data.fill_(1)
self.bias.data.fill_(0)
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
### weights for centering calibration ###
self.center_weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.center_weight.data.fill_(0)
### weights for scaling calibration ###
self.scale_weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.scale_bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.scale_weight.data.fill_(0)
self.scale_bias.data.fill_(1)
### calculate statistics ###
self.stas = nn.AdaptiveAvgPool2d((1,1))
def forward(self, input):
self._check_input_dim(input)
####### centering calibration begin #######
input += self.center_weight.view(1,self.num_features,1,1)*self.stas(input)
####### centering calibration end #######
####### BatchNorm begin #######
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
if self.num_batches_tracked is not None:
self.num_batches_tracked = self.num_batches_tracked + 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else:
exponential_average_factor = self.momentum
output = F.batch_norm(
input, self.running_mean, self.running_var, None, None,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
####### BatchNorm end #######
####### scaling calibration begin #######
scale_factor = torch.sigmoid(self.scale_weight*self.stas(output)+self.scale_bias)
####### scaling calibration end #######
if self.affine:
return self.weight*scale_factor*output + self.bias
else:
return scale_factor*output
其中大部分代碼跟Pytorch自己實(shí)現(xiàn)的BatchNorm類似,我們簡(jiǎn)單關(guān)注幾點(diǎn)
首先在初始化里,初始化了中心校準(zhǔn),縮放校準(zhǔn)所需的可學(xué)習(xí)參數(shù),并填充默認(rèn)值
### weights for centering calibration ###
self.center_weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.center_weight.data.fill_(0)
### weights for scaling calibration ###
self.scale_weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.scale_bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.scale_weight.data.fill_(0)
self.scale_bias.data.fill_(1)
我們經(jīng)常會(huì)把可學(xué)習(xí)參數(shù)中,權(quán)重w初始化為1,偏置b初始化為0,而這里恰恰相反,將權(quán)重則初始化為0,偏置則為1。個(gè)人推測(cè)可以參考推導(dǎo)Centering Calibration中,當(dāng)w為0時(shí),則等價(jià)于原始的BN,從而后續(xù)讓模型根據(jù)需要來去調(diào)整w。但為什么偏置設(shè)為1,筆者沒想清楚。可以參考RBN開源工程的issue1,地址在這篇文章開頭
然后是初始化我們的實(shí)例特征提取操作,這里是用一個(gè)全局池化
### calculate statistics ###
self.stas = nn.AdaptiveAvgPool2d((1,1))
在forward函數(shù)一開始,我們先做中心校準(zhǔn)操作
####### centering calibration begin #######
input += self.center_weight.view(1,self.num_features,1,1)*self.stas(input)
####### centering calibration end #######
然后是調(diào)用torch自帶的Batchnorm
...
output = F.batch_norm(
input, self.running_mean, self.running_var, None, None,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
接著做縮放校準(zhǔn)操作
####### scaling calibration begin #######
scale_factor = torch.sigmoid(self.scale_weight*self.stas(output)+self.scale_bias)
####### scaling calibration end #######
最后根據(jù)屬性 self.affine 做最后的拉伸和偏移
if self.affine:
return self.weight*scale_factor*output + self.bias
else:
return scale_factor*output
總結(jié)
作者提出了一種簡(jiǎn)單有效的方法,將BN層的mini-batch的統(tǒng)計(jì)特征和各個(gè)實(shí)例獨(dú)自的特征(Representative也就體現(xiàn)在這里)巧妙的結(jié)合起來,使得能夠更好自適應(yīng)集合里的數(shù)據(jù),最后各個(gè)實(shí)驗(yàn)也證明了其有效性。期待更多在Norm方面的工作~
歡迎關(guān)注GiantPandaCV, 在這里你將看到獨(dú)家的深度學(xué)習(xí)分享,堅(jiān)持原創(chuàng),每天分享我們學(xué)習(xí)到的新鮮知識(shí)。( ? ?ω?? )?
有對(duì)文章相關(guān)的問題,或者想要加入交流群,歡迎添加BBuf微信:
