InstanceNorm 梯度公式推導
InstanceNorm 梯度公式推導
【GiantPandaCV導語】本文主內(nèi)容是推導 InstanceNorm 關于輸入和參數(shù)的梯度公式,同時還會結(jié)合 Pytorch 和 MXNet 里面 InstanceNorm 的代碼來分析。
InstanceNorm 與 BatchNorm 的聯(lián)系
對一個形狀為 (N, C, H, W) 的張量應用 InstanceNorm[4] 操作,其實等價于先把該張量 reshape 為 (1, N * C, H, W)的張量,然后應用 BatchNorm[5] 操作。而 gamma 和 beta 參數(shù)的每個通道所對應輸入張量的位置都是一致的。
而 InstanceNorm 與 BatchNorm 不同的地方在于:
InstanceNorm 訓練與預測階段行為一致,都是利用當前 batch 的均值和方差計算; BatchNorm 訓練階段利用當前 batch 的均值和方差,測試階段則利用訓練階段通過移動平均統(tǒng)計的均值和方差;
論文[6]中的一張示意圖,就很好的解釋了兩者的聯(lián)系:

所以 InstanceNorm 對于輸入梯度和參數(shù)求導過程與 BatchNorm 類似,下面開始進入正題。
梯度推導過程詳解
在開始推導梯度公式之前,首先約定輸入,參數(shù),輸出等符號:
輸入張量 , 形狀為(N, C, H, W),rehape 為 (1, N * C, M) 其中 M=H*W 參數(shù) ,形狀為 (1, C, 1, 1),每個通道值對應 N*M 個輸入,在計算的時候首先通過在第0維 repeat N 次再 reshape 成 (1, N*C, 1, 1); 參數(shù) ,形狀為 (1, C, 1, 1),每個通道值對應 N*M 個輸入,在計算的時候首先通過在第0維 repeat N 次再 reshape 成 (1, N*C, 1, 1);
而輸入張量 reshape 成 (1, N * C, M)之后,每個通道上是一個長度為 M 的向量,這些向量之間的計算是不像干的,每個向量計算自己的 normalize 結(jié)果。所以求導也是各自獨立。因此下面的均值、方差符號約定和求導也只關注于其中一個向量,其他通道上的向量計算都是一樣的。
一個向量上的均值 一個向量上的方差 一個向量上一個點的 normalize 中間輸出 一個向量上一個點的 normalize 最終輸出 ,其中 和 表示這個向量所對應的 gamma 和 beta 參數(shù)的通道值。 loss 函數(shù)的符號約定為
gamma 和 beta 參數(shù)梯度的推導
先計算簡單的部分,求 loss 對 和 的偏導:
其中 表示 gamma 和 beta 參數(shù)的第 個通道參與了哪些 batch 上向量的 normalize 計算。
因為 gamma 和 beta 上的每個通道的參數(shù)都參數(shù)與了 N 個 batch 上 M 個元素 normalize 的計算,所以對每個通道進行求導的時候,需要把所有涉及到的位置的梯度都累加在一起。
對于 在具體實現(xiàn)的時候,就是對應輸出梯度的值,也就是從上一層回傳回來的梯度值。
輸入梯度的推導
對輸入梯度的求導是最復雜的,下面的推導都是求 loss 相對于輸入張量上的一個點上的梯度,而因為上文已知,每個長度是 M 的向量的計算都是獨立的,所以下文也是描述其中一個向量上一個點的梯度公式。具體是計算的時候,是通過向量操作(比如 numpy)來完成所有點的梯度計算。
先看 loss 函數(shù)對于 的求導:
而從上文約定的公式可知,對于 的計算中涉及到 的有三部分,分別是 、 和 。所以 loss 對于 的偏導可以寫成以下的形式:
接下來就是,分別求上面式子最后三項的梯度公式
第一項梯度推導
在求第一項的時候,把 和 看做常量,則有:
然后有:
最后可得第一項梯度公式:
第三項梯度推導
接著先看第三項梯度 ,因為第三項的推導形式簡單一些。
先計算上式最后一項 ,把 看做常量:
然后計算 ,等價于求 。而因為每個長度是 M 的向量都會計算一個方差 ,而計算出來的方差又會參數(shù)到所有 M 個元素的 normalize 的計算,所以 loss 對于 的偏導需要把所有 M 個位置的梯度累加,所以有:
接著計算 ,
最后可得:
第二項梯度推導
最后計算第二項的梯度 ,一樣先計算最后一項 :
接著計算 ,等價于是求 。而因為每個長度是 M 的向量都會計算一個均值 ,而計算出來的均值又會參與到所有 M 個元素的 normalize 的計算,所以 loss 對于 的偏導需要把所有 M 個位置的梯度累加,所以有:
接著計算 ,
最后可得:
輸入梯度最終的公式
分別計算完上面三項,就能得到對于輸入張量每個位置上梯度的最終公式了:
觀察上式可以發(fā)現(xiàn),loss 對 的求導公式包括了 loss 對 求導的公式,所以這也是為什么先計算第三項的原因,在下面代碼實現(xiàn)上也可以體現(xiàn)。
而在具體實現(xiàn)的時候就是直接套公式計算就可以了,下面來看下在 Pytroch 和 MXNet 框架中對 InstanceNorm 的實現(xiàn)。
主流框架實現(xiàn)代碼解讀
Pytroch 前向傳播實現(xiàn)
前向代碼鏈接:https://github.com/pytorch/pytorch/blob/fa153184c8f70259337777a1fd1d803c7325f758/aten%2Fsrc%2FATen%2Fnative%2FNormalization.cpp#L506
為了可讀性簡化了些代碼:
Tensor?instance_norm(
????const?Tensor&?input,?
????const?Tensor&?weight/*?optional?*/,?
????const?Tensor&?bias/*?optional?*/,
????const?Tensor&?running_mean/*?optional?*/,?
????const?Tensor&?running_var/*?optional?*/,
????bool?use_input_stats,?
????double?momentum,?
????double?eps,?
????bool?cudnn_enabled)?{
??//?......
??std::vector<int64_t>?shape?=?
????input.sizes().vec();
??int64_t?b?=?input.size(0);
??int64_t?c?=?input.size(1);
??//?shape?從?(b,?c,?h,?w)
??//?變?yōu)?(1,?b*c,?h,?w)
??shape[1]?=?b?*?c;
??shape[0]?=?1;
??//?repeat_if_defined?的解釋見下文
??Tensor?weight_?=?
??????repeat_if_defined(weight,?b);
??Tensor?bias_?=?
??????repeat_if_defined(bias,?b);
??Tensor?running_mean_?=?
??????repeat_if_defined(running_mean,?b);
??Tensor?running_var_?=?
??????repeat_if_defined(running_var,?b);
??//?改變輸入張量的形狀
??auto?input_reshaped?=?
??????input.contiguous().view(shape);
??//?計算實際調(diào)用的是?batchnorm?的實現(xiàn)
??//?所以可以理解為什么?pytroch?
??//?前端?InstanceNorm2d?的接口
??//?與?BatchNorm2d?的接口一樣
??auto?out?=?at::batch_norm(
????input_reshaped,?
????weight_,?bias_,?
????running_mean_,?
????running_var_,
????use_input_stats,?
????momentum,
????eps,?cudnn_enabled);
??//?......
??return?out.view(input.sizes());
}
repeat_if_defined 的代碼:
https://github.com/pytorch/pytorch/blob/fa153184c8f70259337777a1fd1d803c7325f758/aten%2Fsrc%2FATen%2Fnative%2FNormalization.cpp#L27
static?inline?Tensor?repeat_if_defined(
??const?Tensor&?t,?
??int64_t?repeat)?{
??if?(t.defined())?{
????//?把?tensor?按第0維度復制?repeat?次
????return?t.repeat(repeat);
??}
??return?t;
}
從 pytorch 前向傳播的實現(xiàn)上看,驗證了本文開頭說的關于 InstanceNorm 與 BatchNorm 的聯(lián)系。還有對于參數(shù) gamma 與 beta 的處理方式。
MXNet 反向傳播實現(xiàn)
因為我個人感覺 MXNet InstanceNorm 的反向傳播實現(xiàn)很直觀,所以選擇解讀其實現(xiàn):
https://github.com/apache/incubator-mxnet/blob/4a7282f104590023d846f505527fd0d490b65509/src%2Foperator%2Finstance_norm-inl.h#L112
同樣為了可讀性簡化了些代碼:
template<typename?xpu>
void?InstanceNormBackward(
????const?nnvm::NodeAttrs&?attrs,
????const?OpContext?&ctx,
????const?std::vector?&inputs,
????const?std::vector?&req,
????const?std::vector?&outputs) ?{
??using?namespace?mshadow;
??using?namespace?mshadow::expr;
??//?......
??const?InstanceNormParam&?param?=?
??????nnvm::get(
????????attrs.parsed);
??Stream?*s?=?
??????ctx.get_stream();
??//?獲取輸入張量的形狀
??mxnet::TShape?dshape?=?
??????inputs[3].shape_;
??//?......
??int?n?=?inputs[3].size(0);
??int?c?=?inputs[3].size(1);
??//?rest_dim?就等于上文的?M
??int?rest_dim?=
??????static_cast<int>(
????????inputs[3].Size()?/?n?/?c);
??Shape<2>?s2?=?Shape2(n?*?c,?rest_dim);
??Shape<3>?s3?=?Shape3(n,?c,?rest_dim);
??//?scale?就等于上文的?1/M
??const?real_t?scale?=?
??????static_cast<real_t>(1)?/?
??????????static_cast<real_t>(rest_dim);
??//?獲取輸入張量
??Tensor2>?data?=?inputs[3]
???.get_with_shape2,?real_t>(s2,?s);
??//?保存輸入梯度
??Tensor2>?gdata?=?outputs[kData]
???.get_with_shape2,?real_t>(s2,?s);
??//?獲取參數(shù)?gamma?
??Tensor1>?gamma?=
??????inputs[4].get1,?real_t>(s);
??//?保存參數(shù)?gamma?梯度計算結(jié)果
??Tensor1>?ggamma?=?outputs[kGamma]
??????.get1,?real_t>(s);
??//?保存參數(shù)?beta?梯度計算結(jié)果
??Tensor1>?gbeta?=?outputs[kBeta]
??????.get1,?real_t>(s);
??//?獲取輸出梯度
??Tensor2>?gout?=?inputs[0]
??????.get_with_shape2,?real_t>(
????????s2,?s);
??//?獲取前向計算好的均值和方差
??Tensor1>?var?=?
????inputs[2].FlatTo1Dreal_t>(s);
??Tensor1>?mean?=?
????inputs[1].FlatTo1Dreal_t>(s);
??//?臨時空間
??Tensor2>?workspace?=?//.....
??//?保存均值的梯度
??Tensor1>?gmean?=?workspace[0];
??//?保存方差的梯度
??Tensor1>?gvar?=?workspace[1];
??Tensor1>?tmp?=?workspace[2];
??//?計算方差的梯度,
??//?對應上文輸入梯度公式的第三項
??//?gout?對應輸出梯度
??gvar?=?sumall_except_dim<0>(
????(gout?*?broadcast<0>(
??????reshape(repmat(gamma,?n),?
????????Shape1(n?*?c)),?data.shape_))?*
??????(data?-?broadcast<0>(
????????mean,?data.shape_))?*?-0.5f?*
??????F(
????????broadcast<0>(
??????????var?+?param.eps,?data.shape_),?
??????-1.5f)
????);
??//?計算均值的梯度,
??//?對應上文輸入梯度公式的第二項
??gmean?=?sumall_except_dim<0>(
????gout?*?broadcast<0>(
??????reshape(repmat(gamma,?n),?
????????Shape1(n?*?c)),?data.shape_));
??gmean?*=?
????-1.0f?/?F(
??????var?+?param.eps);
??tmp?=?scale?*?sumall_except_dim<0>(
??????????-2.0f?*?(data?-?broadcast<0>(
????????????mean,?data.shape_)));
??tmp?*=?gvar;
??gmean?+=?tmp;
??//?計算?beta?的梯度
??//?記得s3?=?Shape3(n,?c,?rest_dim)
??//?那么swapaxis<1,?0>(reshape(gout,?s3))
??//?就表示首先把輸出梯度?reshape?成
??//?(n,?c,?rest_dim),接著交換第0和1維度
??//?(c,?n,?rest_dim),最后求除了第0維度
??//?之外其他維度的和,
??//?也就和?beta?的求導公式對應上了
??Assign(gbeta,?req[kBeta],
????sumall_except_dim<0>(
???????swapaxis<1,?0>(reshape(gout,?s3))));
???????
??//?計算?gamma?的梯度
??//?swapaxis<1,?0>?的作用與上面?beta?一樣
??Assign(ggamma,?req[kGamma],
????sumall_except_dim<0>(
??????swapaxis<1,?0>(
????????reshape(gout?*?
?????????(data?-?broadcast<0>(mean,?
???????????data.shape_))?
???????????/?F(
???????????????broadcast<0>(
????????????????var?+?param.eps,
??????????????????data.shape_
???????????????)
?????????????),?s3
????????)
??????)
????)
??);
??//?計算輸入的梯度,
??//?對應上文輸入梯度公式三項的相加
??Assign(gdata,?req[kData],
????(gout?*?broadcast<0>(
??????reshape(repmat(gamma,?n),?
????????Shape1(n?*?c)),?data.shape_))
??????*?broadcast<0>(1.0f?/?
????????F(
??????????var?+?param.eps),?data.shape_)?
????????????????
????+?broadcast<0>(gvar,?data.shape_)?
??????*?scale?*?2.0f?
??????*?(data?-?broadcast<0>(
????????mean,?data.shape_))?
????
????+?broadcast<0>(gmean,?
??????data.shape_)?*?scale);
}
可以看到基于 mshadow 模板庫的反向傳播實現(xiàn),看起來很直觀,基本是和公式能對應上的。
InstanceNorm numpy 實現(xiàn)
最后看下 InstanceNorm 前向計算與求輸入梯度的 numpy 實現(xiàn)
import?numpy?as?np
import?torch
eps?=?1e-05
batch?=?4
channel?=?2
height?=?32
width?=?32
input?=?np.random.random(
????size=(batch,?channel,?height,?width)).astype(np.float32)
#?gamma?初始化為1
#?beta?初始化為0,所以忽略了
gamma?=?np.ones((1,?channel,?1,?1),?
????dtype=np.float32)
#?隨機生成輸出梯度
gout?=?np.random.random(
????size=(batch,?channel,?height,?width))\
????.astype(np.float32)
#?用numpy計算前向的結(jié)果
mean_np?=?np.mean(
??input,?axis=(2,?3),?keepdims=True)
in_sub_mean?=?input?-?mean_np
var_np?=?np.mean(
????np.square(in_sub_mean),?
??????axis=(2,?3),?keepdims=True)
invar_np?=?1.0?/?np.sqrt(var_np?+?eps)
out_np?=?in_sub_mean?*?invar_np?*?gamma
#?用numpy計算輸入梯度
scale?=?1.0?/?(height?*?width)
#?對應輸入梯度公式第三項
gvar?=?
??gout?*?gamma?*?in_sub_mean?*
???-0.5?*?np.power(var_np?+?eps,?-1.5)
gvar?=?np.sum(gvar,?axis=(2,?3),?
????????keepdims=True)
#?對應輸入梯度公式第二項
gmean?=?np.sum(
????gout?*?gamma,?
????axis=(2,?3),?keepdims=True)
gmean?*=?-invar_np
tmp?=?scale?*?np.sum(-2.0?*?in_sub_mean,?
????????axis=(2,?3),?keepdims=True)?
gmean?+=?tmp?*?gvar
#?對應輸入梯度公式三項之和
gin_np?=?
??gout?*?gamma?*?invar_np
????+?gvar?*?scale?*?2.0?*?in_sub_mean
????+?gmean?*?scale
#?pytorch?的實現(xiàn)
p_input_tensor?=?
??torch.tensor(input,?requires_grad=True)
trans?=?torch.nn.InstanceNorm2d(
??channel,?affine=True,?eps=eps)
p_output_tensor?=?trans(p_input_tensor)
p_output_tensor.backward(
??torch.Tensor(gout))
#?與?pytorch?對比結(jié)果
print(np.allclose(out_np,?
??p_output_tensor.detach().numpy(),?
??atol=1e-5))
print(np.allclose(gin_np,?
??p_input_tensor.grad.numpy(),?
??atol=1e-5))
#?命令行輸出
#?True
#?True
總結(jié)
本文對于 InstanceNorm 的梯度公式推導大部分參考了博客[1][2]的內(nèi)容,然后在參考博客的基礎上,按自己的理解具體推導了一遍,很多時候是從結(jié)果往回推,在推導過程中會有不太嚴謹?shù)牡胤?,如果有什么疑惑或意見,歡迎交流。
參考資料:
[1] https://medium.com/@drsealks/batch-normalisation-formulas-derivation-253df5b75220 [2] https://kevinzakka.github.io/2016/09/14/batch_normalization/ [3] https://www.zhihu.com/question/68730628 [4] https://arxiv.org/pdf/1607.08022.pdf [5] https://arxiv.org/pdf/1502.03167v3.pdf [6] https://arxiv.org/pdf/1803.08494.pdf
