<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          InstanceNorm 梯度公式推導

          共 9304字,需瀏覽 19分鐘

           ·

          2020-12-24 08:53

          InstanceNorm 梯度公式推導

          【GiantPandaCV導語】本文主內(nèi)容是推導 InstanceNorm 關于輸入和參數(shù)的梯度公式,同時還會結(jié)合 Pytorch 和 MXNet 里面 InstanceNorm 的代碼來分析。

          InstanceNorm 與 BatchNorm 的聯(lián)系

          對一個形狀為 (N, C, H, W) 的張量應用 InstanceNorm[4] 操作,其實等價于先把該張量 reshape 為 (1, N * C, H, W)的張量,然后應用 BatchNorm[5] 操作。而 gamma 和 beta 參數(shù)的每個通道所對應輸入張量的位置都是一致的。

          而 InstanceNorm 與 BatchNorm 不同的地方在于:

          • InstanceNorm 訓練與預測階段行為一致,都是利用當前 batch 的均值和方差計算;
          • BatchNorm 訓練階段利用當前 batch 的均值和方差,測試階段則利用訓練階段通過移動平均統(tǒng)計的均值和方差;

          論文[6]中的一張示意圖,就很好的解釋了兩者的聯(lián)系:

          https://arxiv.org/pdf/1803.08494.pdf

          所以 InstanceNorm 對于輸入梯度和參數(shù)求導過程與 BatchNorm 類似,下面開始進入正題。

          梯度推導過程詳解

          在開始推導梯度公式之前,首先約定輸入,參數(shù),輸出等符號:

          • 輸入張量 , 形狀為(N, C, H, W),rehape 為 (1, N * C, M) 其中 M=H*W
          • 參數(shù) ,形狀為 (1, C, 1, 1),每個通道值對應 N*M 個輸入,在計算的時候首先通過在第0維 repeat N 次再 reshape 成 (1, N*C, 1, 1);
          • 參數(shù) ,形狀為 (1, C, 1, 1),每個通道值對應 N*M 個輸入,在計算的時候首先通過在第0維 repeat N 次再 reshape 成 (1, N*C, 1, 1);

          而輸入張量 reshape 成 (1, N * C, M)之后,每個通道上是一個長度為 M 的向量,這些向量之間的計算是不像干的,每個向量計算自己的 normalize 結(jié)果。所以求導也是各自獨立。因此下面的均值、方差符號約定和求導也只關注于其中一個向量,其他通道上的向量計算都是一樣的。

          • 一個向量上的均值
          • 一個向量上的方差
          • 一個向量上一個點的 normalize 中間輸出
          • 一個向量上一個點的 normalize 最終輸出 ,其中 表示這個向量所對應的 gamma 和 beta 參數(shù)的通道值。
          • loss 函數(shù)的符號約定為

          gamma 和 beta 參數(shù)梯度的推導

          先計算簡單的部分,求 loss 對 的偏導:



          其中 表示 gamma 和 beta 參數(shù)的第 個通道參與了哪些 batch 上向量的 normalize 計算。

          因為 gamma 和 beta 上的每個通道的參數(shù)都參數(shù)與了 N 個 batch 上 M 個元素 normalize 的計算,所以對每個通道進行求導的時候,需要把所有涉及到的位置的梯度都累加在一起。

          對于 在具體實現(xiàn)的時候,就是對應輸出梯度的值,也就是從上一層回傳回來的梯度值。

          輸入梯度的推導

          對輸入梯度的求導是最復雜的,下面的推導都是求 loss 相對于輸入張量上的一個點上的梯度,而因為上文已知,每個長度是 M 的向量的計算都是獨立的,所以下文也是描述其中一個向量上一個點的梯度公式。具體是計算的時候,是通過向量操作(比如 numpy)來完成所有點的梯度計算。

          先看 loss 函數(shù)對于 的求導:


          而從上文約定的公式可知,對于 的計算中涉及到 的有三部分,分別是 、。所以 loss 對于 的偏導可以寫成以下的形式:


          接下來就是,分別求上面式子最后三項的梯度公式

          第一項梯度推導

          在求第一項的時候,把 看做常量,則有:


          然后有:


          最后可得第一項梯度公式:


          第三項梯度推導

          接著先看第三項梯度 ,因為第三項的推導形式簡單一些。

          先計算上式最后一項 ,把 看做常量:


          然后計算 ,等價于求 。而因為每個長度是 M 的向量都會計算一個方差 ,而計算出來的方差又會參數(shù)到所有 M 個元素的 normalize 的計算,所以 loss 對于 的偏導需要把所有 M 個位置的梯度累加,所以有:


          接著計算 ,


          最后可得:



          第二項梯度推導

          最后計算第二項的梯度 ,一樣先計算最后一項


          接著計算 ,等價于是求 。而因為每個長度是 M 的向量都會計算一個均值 ,而計算出來的均值又會參與到所有 M 個元素的 normalize 的計算,所以 loss 對于 的偏導需要把所有 M 個位置的梯度累加,所以有:


          接著計算 ,


          最后可得:



          輸入梯度最終的公式

          分別計算完上面三項,就能得到對于輸入張量每個位置上梯度的最終公式了:


          觀察上式可以發(fā)現(xiàn),loss 對 的求導公式包括了 loss 對 求導的公式,所以這也是為什么先計算第三項的原因,在下面代碼實現(xiàn)上也可以體現(xiàn)。

          而在具體實現(xiàn)的時候就是直接套公式計算就可以了,下面來看下在 Pytroch 和 MXNet 框架中對 InstanceNorm 的實現(xiàn)。

          主流框架實現(xiàn)代碼解讀

          Pytroch 前向傳播實現(xiàn)

          前向代碼鏈接:https://github.com/pytorch/pytorch/blob/fa153184c8f70259337777a1fd1d803c7325f758/aten%2Fsrc%2FATen%2Fnative%2FNormalization.cpp#L506

          為了可讀性簡化了些代碼:

          Tensor?instance_norm(
          ????const?Tensor&?input,?
          ????const?Tensor&?weight/*?optional?*/,?
          ????const?Tensor&?bias/*?optional?*/,
          ????const?Tensor&?running_mean/*?optional?*/,?
          ????const?Tensor&?running_var/*?optional?*/,
          ????bool?use_input_stats,?
          ????double?momentum,?
          ????double?eps,?
          ????bool?cudnn_enabled)
          ?
          {
          ??//?......
          ??std::vector<int64_t>?shape?=?
          ????input.sizes().vec();
          ??int64_t?b?=?input.size(0);
          ??int64_t?c?=?input.size(1);
          ??//?shape?從?(b,?c,?h,?w)
          ??//?變?yōu)?(1,?b*c,?h,?w)
          ??shape[1]?=?b?*?c;
          ??shape[0]?=?1;
          ??//?repeat_if_defined?的解釋見下文
          ??Tensor?weight_?=?
          ??????repeat_if_defined(weight,?b);
          ??Tensor?bias_?=?
          ??????repeat_if_defined(bias,?b);
          ??Tensor?running_mean_?=?
          ??????repeat_if_defined(running_mean,?b);
          ??Tensor?running_var_?=?
          ??????repeat_if_defined(running_var,?b);
          ??//?改變輸入張量的形狀
          ??auto?input_reshaped?=?
          ??????input.contiguous().view(shape);
          ??//?計算實際調(diào)用的是?batchnorm?的實現(xiàn)
          ??//?所以可以理解為什么?pytroch?
          ??//?前端?InstanceNorm2d?的接口
          ??//?與?BatchNorm2d?的接口一樣
          ??auto?out?=?at::batch_norm(
          ????input_reshaped,?
          ????weight_,?bias_,?
          ????running_mean_,?
          ????running_var_,
          ????use_input_stats,?
          ????momentum,
          ????eps,?cudnn_enabled);
          ??//?......
          ??return?out.view(input.sizes());
          }

          repeat_if_defined 的代碼:

          https://github.com/pytorch/pytorch/blob/fa153184c8f70259337777a1fd1d803c7325f758/aten%2Fsrc%2FATen%2Fnative%2FNormalization.cpp#L27

          static?inline?Tensor?repeat_if_defined(
          ??const?Tensor&?t,?
          ??int64_t?repeat)
          ?
          {
          ??if?(t.defined())?{
          ????//?把?tensor?按第0維度復制?repeat?次
          ????return?t.repeat(repeat);
          ??}
          ??return?t;
          }

          從 pytorch 前向傳播的實現(xiàn)上看,驗證了本文開頭說的關于 InstanceNorm 與 BatchNorm 的聯(lián)系。還有對于參數(shù) gamma 與 beta 的處理方式。

          MXNet 反向傳播實現(xiàn)

          因為我個人感覺 MXNet InstanceNorm 的反向傳播實現(xiàn)很直觀,所以選擇解讀其實現(xiàn):

          https://github.com/apache/incubator-mxnet/blob/4a7282f104590023d846f505527fd0d490b65509/src%2Foperator%2Finstance_norm-inl.h#L112

          同樣為了可讀性簡化了些代碼:

          template<typename?xpu>
          void?InstanceNormBackward(
          ????const?nnvm::NodeAttrs&?attrs,
          ????const?OpContext?&ctx,
          ????const?std::vector?&inputs,
          ????const?std::vector?&req,
          ????const?std::vector?&outputs)
          ?
          {
          ??using?namespace?mshadow;
          ??using?namespace?mshadow::expr;
          ??//?......
          ??const?InstanceNormParam&?param?=?
          ??????nnvm::get(
          ????????attrs.parsed);

          ??Stream?*s?=?
          ??????ctx.get_stream();
          ??//?獲取輸入張量的形狀
          ??mxnet::TShape?dshape?=?
          ??????inputs[3].shape_;
          ??//?......
          ??int?n?=?inputs[3].size(0);
          ??int?c?=?inputs[3].size(1);
          ??//?rest_dim?就等于上文的?M
          ??int?rest_dim?=
          ??????static_cast<int>(
          ????????inputs[3].Size()?/?n?/?c);
          ??Shape<2>?s2?=?Shape2(n?*?c,?rest_dim);
          ??Shape<3>?s3?=?Shape3(n,?c,?rest_dim);
          ??//?scale?就等于上文的?1/M
          ??const?real_t?scale?=?
          ??????static_cast<real_t>(1)?/?
          ??????????static_cast<real_t>(rest_dim);
          ??//?獲取輸入張量
          ??Tensor2>?data?=?inputs[3]
          ???.get_with_shape2,?real_t>(s2,?s);
          ??//?保存輸入梯度
          ??Tensor2>?gdata?=?outputs[kData]
          ???.get_with_shape2,?real_t>(s2,?s);
          ??//?獲取參數(shù)?gamma?
          ??Tensor1>?gamma?=
          ??????inputs[4].get1,?real_t>(s);
          ??//?保存參數(shù)?gamma?梯度計算結(jié)果
          ??Tensor1>?ggamma?=?outputs[kGamma]
          ??????.get1,?real_t>(s);
          ??//?保存參數(shù)?beta?梯度計算結(jié)果
          ??Tensor1>?gbeta?=?outputs[kBeta]
          ??????.get1,?real_t>(s);
          ??//?獲取輸出梯度
          ??Tensor2>?gout?=?inputs[0]
          ??????.get_with_shape2,?real_t>(
          ????????s2,?s);
          ??//?獲取前向計算好的均值和方差
          ??Tensor1>?var?=?
          ????inputs[2].FlatTo1Dreal_t>(s);
          ??Tensor1>?mean?=?
          ????inputs[1].FlatTo1Dreal_t>(s);
          ??//?臨時空間
          ??Tensor2>?workspace?=?//.....
          ??//?保存均值的梯度
          ??Tensor1>?gmean?=?workspace[0];
          ??//?保存方差的梯度
          ??Tensor1>?gvar?=?workspace[1];
          ??Tensor1>?tmp?=?workspace[2];

          ??//?計算方差的梯度,
          ??//?對應上文輸入梯度公式的第三項
          ??//?gout?對應輸出梯度
          ??gvar?=?sumall_except_dim<0>(
          ????(gout?*?broadcast<0>(
          ??????reshape(repmat(gamma,?n),?
          ????????Shape1(n?*?c)),?data.shape_))?*
          ??????(data?-?broadcast<0>(
          ????????mean,?data.shape_))?*?-0.5f?*
          ??????F(
          ????????broadcast<0>(
          ??????????var?+?param.eps,?data.shape_),?
          ??????-1.5f)
          ????);
          ??//?計算均值的梯度,
          ??//?對應上文輸入梯度公式的第二項
          ??gmean?=?sumall_except_dim<0>(
          ????gout?*?broadcast<0>(
          ??????reshape(repmat(gamma,?n),?
          ????????Shape1(n?*?c)),?data.shape_));
          ??gmean?*=?
          ????-1.0f?/?F(
          ??????var?+?param.eps);
          ??tmp?=?scale?*?sumall_except_dim<0>(
          ??????????-2.0f?*?(data?-?broadcast<0>(
          ????????????mean,?data.shape_)));
          ??tmp?*=?gvar;
          ??gmean?+=?tmp;

          ??//?計算?beta?的梯度
          ??//?記得s3?=?Shape3(n,?c,?rest_dim)
          ??//?那么swapaxis<1,?0>(reshape(gout,?s3))
          ??//?就表示首先把輸出梯度?reshape?成
          ??//?(n,?c,?rest_dim),接著交換第0和1維度
          ??//?(c,?n,?rest_dim),最后求除了第0維度
          ??//?之外其他維度的和,
          ??//?也就和?beta?的求導公式對應上了
          ??Assign(gbeta,?req[kBeta],
          ????sumall_except_dim<0>(
          ???????swapaxis<1,?0>(reshape(gout,?s3))));
          ???????
          ??//?計算?gamma?的梯度
          ??//?swapaxis<1,?0>?的作用與上面?beta?一樣
          ??Assign(ggamma,?req[kGamma],
          ????sumall_except_dim<0>(
          ??????swapaxis<1,?0>(
          ????????reshape(gout?*?
          ?????????(data?-?broadcast<0>(mean,?
          ???????????data.shape_))?
          ???????????/?F(
          ???????????????broadcast<0>(
          ????????????????var?+?param.eps,
          ??????????????????data.shape_
          ???????????????)
          ?????????????),?s3
          ????????)
          ??????)
          ????)
          ??);
          ??//?計算輸入的梯度,
          ??//?對應上文輸入梯度公式三項的相加
          ??Assign(gdata,?req[kData],
          ????(gout?*?broadcast<0>(
          ??????reshape(repmat(gamma,?n),?
          ????????Shape1(n?*?c)),?data.shape_))
          ??????*?broadcast<0>(1.0f?/?
          ????????F(
          ??????????var?+?param.eps),?data.shape_)?
          ????????????????
          ????+?broadcast<0>(gvar,?data.shape_)?
          ??????*?scale?*?2.0f?
          ??????*?(data?-?broadcast<0>(
          ????????mean,?data.shape_))?
          ????
          ????+?broadcast<0>(gmean,?
          ??????data.shape_)?*?scale);
          }

          可以看到基于 mshadow 模板庫的反向傳播實現(xiàn),看起來很直觀,基本是和公式能對應上的。

          InstanceNorm numpy 實現(xiàn)

          最后看下 InstanceNorm 前向計算與求輸入梯度的 numpy 實現(xiàn)

          import?numpy?as?np
          import?torch

          eps?=?1e-05
          batch?=?4
          channel?=?2
          height?=?32
          width?=?32

          input?=?np.random.random(
          ????size=(batch,?channel,?height,?width)).astype(np.float32)
          #?gamma?初始化為1
          #?beta?初始化為0,所以忽略了
          gamma?=?np.ones((1,?channel,?1,?1),?
          ????dtype=np.float32)
          #?隨機生成輸出梯度
          gout?=?np.random.random(
          ????size=(batch,?channel,?height,?width))\
          ????.astype(np.float32)

          #?用numpy計算前向的結(jié)果
          mean_np?=?np.mean(
          ??input,?axis=(2,?3),?keepdims=True)
          in_sub_mean?=?input?-?mean_np
          var_np?=?np.mean(
          ????np.square(in_sub_mean),?
          ??????axis=(2,?3),?keepdims=True)
          invar_np?=?1.0?/?np.sqrt(var_np?+?eps)
          out_np?=?in_sub_mean?*?invar_np?*?gamma

          #?用numpy計算輸入梯度
          scale?=?1.0?/?(height?*?width)
          #?對應輸入梯度公式第三項
          gvar?=?
          ??gout?*?gamma?*?in_sub_mean?*
          ???-0.5?*?np.power(var_np?+?eps,?-1.5)
          gvar?=?np.sum(gvar,?axis=(2,?3),?
          ????????keepdims=True)

          #?對應輸入梯度公式第二項
          gmean?=?np.sum(
          ????gout?*?gamma,?
          ????axis=(2,?3),?keepdims=True)
          gmean?*=?-invar_np
          tmp?=?scale?*?np.sum(-2.0?*?in_sub_mean,?
          ????????axis=(2,?3),?keepdims=True)?
          gmean?+=?tmp?*?gvar

          #?對應輸入梯度公式三項之和
          gin_np?=?
          ??gout?*?gamma?*?invar_np
          ????+?gvar?*?scale?*?2.0?*?in_sub_mean
          ????+?gmean?*?scale


          #?pytorch?的實現(xiàn)
          p_input_tensor?=?
          ??torch.tensor(input,?requires_grad=True)
          trans?=?torch.nn.InstanceNorm2d(
          ??channel,?affine=True,?eps=eps)
          p_output_tensor?=?trans(p_input_tensor)
          p_output_tensor.backward(
          ??torch.Tensor(gout))

          #?與?pytorch?對比結(jié)果
          print(np.allclose(out_np,?
          ??p_output_tensor.detach().numpy(),?
          ??atol=1e-5))
          print(np.allclose(gin_np,?
          ??p_input_tensor.grad.numpy(),?
          ??atol=1e-5))

          #?命令行輸出
          #?True
          #?True

          總結(jié)

          本文對于 InstanceNorm 的梯度公式推導大部分參考了博客[1][2]的內(nèi)容,然后在參考博客的基礎上,按自己的理解具體推導了一遍,很多時候是從結(jié)果往回推,在推導過程中會有不太嚴謹?shù)牡胤?,如果有什么疑惑或意見,歡迎交流。

          參考資料:

          • [1] https://medium.com/@drsealks/batch-normalisation-formulas-derivation-253df5b75220
          • [2] https://kevinzakka.github.io/2016/09/14/batch_normalization/
          • [3] https://www.zhihu.com/question/68730628
          • [4] https://arxiv.org/pdf/1607.08022.pdf
          • [5] https://arxiv.org/pdf/1502.03167v3.pdf
          • [6] https://arxiv.org/pdf/1803.08494.pdf
          瀏覽 37
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  黄色电影大香蕉 | 99热在线观看免费精品 | 婷婷激情丁香五月天 | 无码三级电影 | 人妻夜夜夜夜夜夜 |