
          Reparameterization Series | Can lightweight models + reparameterization really take off?

          2022-10-25 00:04

          1. Opening Notes

          1.1 Knowledge Point 1

          For the same architecture, different activation functions can lead to drastically different latency. MobileOne therefore chooses plain ReLU.
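
          A quick way to convince yourself of this is a micro-benchmark. The sketch below is my own illustration (the layer sizes, iteration count, and the choice of GELU/SiLU as comparison points are arbitrary assumptions, not from the paper): it times the same tiny conv stack with different activations.

          import time
          import torch
          import torch.nn as nn

          def make_net(act: nn.Module) -> nn.Sequential:
              # Identical conv stack; only the activation differs.
              return nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), act,
                                   nn.Conv2d(16, 16, 3, padding=1), act)

          def bench(net: nn.Module, x: torch.Tensor, iters: int = 100) -> float:
              with torch.no_grad():
                  net(x)  # warm-up
                  start = time.perf_counter()
                  for _ in range(iters):
                      net(x)
              return (time.perf_counter() - start) / iters * 1e3  # ms per forward

          x = torch.randn(1, 16, 224, 224)
          for act in (nn.ReLU(), nn.GELU(), nn.SiLU()):
              print(f"{type(act).__name__}: {bench(make_net(act).eval(), x):.2f} ms")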

          1.2 Knowledge Point 2

          When a single-branch structure is used, the model runs faster. This was already established with RepVGG; readers can refer to the RepVGG notes for details.
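
          To see the branch effect concretely, here is a minimal timing sketch of my own (shapes and iteration count are arbitrary assumptions): a single 3x3 conv versus three parallel 3x3 convs whose outputs are summed, the latter being exactly the kind of multi-branch structure that reparameterization collapses at inference time.

          import time
          import torch
          import torch.nn as nn

          class MultiBranch(nn.Module):
              # Three parallel 3x3 convs, outputs summed (training-time style).
              def __init__(self):
                  super().__init__()
                  self.branches = nn.ModuleList(nn.Conv2d(16, 16, 3, padding=1) for _ in range(3))

              def forward(self, x):
                  return sum(b(x) for b in self.branches)

          def avg_ms(m: nn.Module, x: torch.Tensor, iters: int = 100) -> float:
              with torch.no_grad():
                  m(x)  # warm-up
                  start = time.perf_counter()
                  for _ in range(iters):
                      m(x)
              return (time.perf_counter() - start) / iters * 1e3

          x = torch.randn(1, 16, 224, 224)
          single = nn.Conv2d(16, 16, 3, padding=1).eval()  # what the branches fuse into
          print(f"single: {avg_ms(single, x):.2f} ms, multi: {avg_ms(MultiBranch().eval(), x):.2f} ms")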

          2. MobileOne in Brief

          The core block of MobileOne is designed on top of MobileNetV1 while absorbing the reparameterization idea, yielding the structure shown in the figure above. Note: the reparameterization scheme here also has a hyperparameter k that controls the number of reparameterization branches (experiments show this variant brings larger gains for small models).

          Looking at the figure, you could, if you like, see it simply as DBB + RepVGG combined: the number of branches can be widened as much as you want, and when reparameterizing you just merge the weights and biases directly, as sketched below.
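
          As a rough sketch of what "merging weight and bias" means (my own minimal example, ignoring BatchNorm and the identity/scale branches, which the full code below handles in _get_kernel_bias): parallel convolutions of the same shape collapse into a single convolution by simply summing their kernels and biases.

          import torch
          import torch.nn as nn

          # Two parallel 3x3 convs applied to the same input, outputs summed
          conv_a = nn.Conv2d(8, 8, 3, padding=1)
          conv_b = nn.Conv2d(8, 8, 3, padding=1)

          # The equivalent single conv just sums the kernels and the biases
          fused = nn.Conv2d(8, 8, 3, padding=1)
          fused.weight.data = conv_a.weight.data + conv_b.weight.data
          fused.bias.data = conv_a.bias.data + conv_b.bias.data

          x = torch.randn(1, 8, 32, 32)
          print(torch.allclose(conv_a(x) + conv_b(x), fused(x), atol=1e-6))  # True

          The 1x1 scale branch and the BN-only skip branch are first rewritten as equivalent kxk kernels (by zero-padding and by an identity kernel, respectively) before this summation, which is exactly what the implementation below does.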

          3. MobileOne Implementation

          Below is a PyTorch implementation of MobileOne:

          from typing import Optional, List, Tuple

          import copy
          import torch
          import torch.nn as nn
          import torch.nn.functional as F


          class MobileOneBlock(nn.Module):
              def __init__(self,
                           in_channels: int,
                           out_channels: int,
                           kernel_size: int,
                           stride: int = 1,
                           padding: int = 0,
                           dilation: int = 1,
                           groups: int = 1,
                           inference_mode: bool = False,
                           use_se: bool = False,
                           num_conv_branches: int = 3) -> None:

                  super(MobileOneBlock, self).__init__()
                  self.inference_mode = inference_mode
                  self.groups = groups
                  self.stride = stride
                  self.kernel_size = kernel_size
                  self.in_channels = in_channels
                  self.out_channels = out_channels
                  self.num_conv_branches = num_conv_branches

                  # SE block omitted in this simplified version (use_se is accepted but ignored); activation is ReLU
                  self.se = nn.Identity()
                  self.activation = nn.ReLU()

                  if inference_mode:
                      self.reparam_conv = nn.Conv2d(in_channels=in_channels,
                                                    out_channels=out_channels,
                                                    kernel_size=kernel_size,
                                                    stride=stride,
                                                    padding=padding,
                                                    dilation=dilation,
                                                    groups=groups,
                                                    bias=True)
                  else:
                      # skip connection
                      self.rbr_skip = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None

                      # 3x3 conv branches
                      rbr_conv = list()
                      for _ in range(self.num_conv_branches):
                          rbr_conv.append(self._conv_bn(kernel_size=kernel_size, padding=padding))
                      self.rbr_conv = nn.ModuleList(rbr_conv)

                      # 1x1 conv branch(scale branch)
                      self.rbr_scale = None
                      if kernel_size > 1:
                          self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)

              def forward(self, x: torch.Tensor) -> torch.Tensor:
                  """ Apply forward pass. """
                  if self.inference_mode:
                      return self.activation(self.se(self.reparam_conv(x)))

                  identity_out = 0
                  if self.rbr_skip is not None:
                      identity_out = self.rbr_skip(x)

                  # Scale branch output
                  scale_out = 0
                  if self.rbr_scale is not None:
                      scale_out = self.rbr_scale(x)

                  # Other branches
                  out = scale_out + identity_out
                  for ix in range(self.num_conv_branches):
                      out += self.rbr_conv[ix](x)

                  return self.activation(self.se(out))

              def reparameterize(self):
                  if self.inference_mode:
                      return
                  kernel, bias = self._get_kernel_bias()
                  self.reparam_conv = nn.Conv2d(in_channels=self.rbr_conv[0].conv.in_channels,
                                                out_channels=self.rbr_conv[0].conv.out_channels,
                                                kernel_size=self.rbr_conv[0].conv.kernel_size,
                                                stride=self.rbr_conv[0].conv.stride,
                                                padding=self.rbr_conv[0].conv.padding,
                                                dilation=self.rbr_conv[0].conv.dilation,
                                                groups=self.rbr_conv[0].conv.groups,
                                                bias=True)
                  self.reparam_conv.weight.data = kernel
                  self.reparam_conv.bias.data = bias

                  for para in self.parameters():
                      para.detach_()
                  self.__delattr__('rbr_conv')
                  self.__delattr__('rbr_scale')
                  if hasattr(self, 'rbr_skip'):
                      self.__delattr__('rbr_skip')

                  self.inference_mode = True

              def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
                  # Get the kernel and bias of the scale (1x1) branch
                  kernel_scale = 0
                  bias_scale = 0
                  if self.rbr_scale is not None:
                      kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
                      # Zero-pad the 1x1 scale kernel to the k x k conv kernel size
                      pad = self.kernel_size // 2
                      kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

                  # Get the kernel and bias of the skip (identity) branch
                  kernel_identity = 0
                  bias_identity = 0
                  if self.rbr_skip is not None:
                      kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

                  # Get the summed kernel and bias of the conv branches
                  kernel_conv = 0
                  bias_conv = 0
                  for ix in range(self.num_conv_branches):
                      _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
                      kernel_conv += _kernel
                      bias_conv += _bias

                  kernel_final = kernel_conv + kernel_scale + kernel_identity
                  bias_final = bias_conv + bias_scale + bias_identity
                  return kernel_final, bias_final

              def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
                  if isinstance(branch, nn.Sequential):
                      kernel = branch.conv.weight
                      running_mean = branch.bn.running_mean
                      running_var = branch.bn.running_var
                      gamma = branch.bn.weight
                      beta = branch.bn.bias
                      eps = branch.bn.eps
                  else:
                      assert isinstance(branch, nn.BatchNorm2d)
                      if not hasattr(self, 'id_tensor'):
                          input_dim = self.in_channels // self.groups
                          kernel_value = torch.zeros((self.in_channels, 
                                                      input_dim, 
                                                      self.kernel_size, 
                                                      self.kernel_size),
                                                      dtype=branch.weight.dtype, 
                                                      device=branch.weight.device)
                          for i in range(self.in_channels):
                              kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1
                          self.id_tensor = kernel_value
                      kernel = self.id_tensor
                      running_mean = branch.running_mean
                      running_var = branch.running_var
                      gamma = branch.weight
                      beta = branch.bias
                      eps = branch.eps
                  std = (running_var + eps).sqrt()
                  t = (gamma / std).reshape(-1, 1, 1, 1)
                  return kernel * t, beta - running_mean * gamma / std

              def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
                  mod_list = nn.Sequential()
                  mod_list.add_module('conv', nn.Conv2d(in_channels=self.in_channels,
                                                        out_channels=self.out_channels,
                                                        kernel_size=kernel_size,
                                                        stride=self.stride,
                                                        padding=padding,
                                                        groups=self.groups,
                                                        bias=False))
                  mod_list.add_module('bn', nn.BatchNorm2d(num_features=self.out_channels))
                  return mod_list


          if __name__ == '__main__':
              model = MobileOneBlock(16, 16, 3, padding=1, num_conv_branches=1)
              # Keep BatchNorm in eval mode so both exports see the same statistics
              model.eval()
              x = torch.ones(1, 16, 9, 9)
              y = model(x)
              torch.onnx.export(model,
                                (x,),
                                'mobileone_raw.onnx',
                                opset_version=12,
                                input_names=['input'],
                                output_names=['output'])
              model.reparameterize()
              torch.onnx.export(model,
                                (x,),
                                'mobileone_rep.onnx',
                                opset_version=12,
                                input_names=['input'],
                                output_names=['output'])

          Without further ado, just compare the two exported ONNX graphs side by side and see how much cleaner the reparameterized one is.
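
          If you prefer a numerical check over eyeballing the graphs, a quick sanity test along these lines (my own addition, using the MobileOneBlock defined above; the tolerance is an assumption) confirms that the fused block matches the multi-branch one:

          block = MobileOneBlock(16, 16, 3, padding=1, num_conv_branches=1).eval()
          x = torch.randn(1, 16, 9, 9)
          with torch.no_grad():
              y_multi_branch = block(x)
              block.reparameterize()
              y_fused = block(x)
          print(torch.allclose(y_multi_branch, y_fused, atol=1e-5))  # expect True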

          4. References

          [1] https://github.com/apple/ml-mobileone
          [2] MobileOne: An Improved One Millisecond Mobile Backbone
