重參系列 | 輕量化模型 + 重參技術,是不是可以起飛?
1、開篇小記
1.1、知識點 1

相同的架構,搭配不同的激活函數,帶來的延遲差異極大。這裡 MobileOne 選擇使用 ReLU。
1.2、知識點 2

當採用單分支結構時,模型具有更快的推理速度。這一點在 RepVGG 中就已經說明過,讀者可以參考 RepVGG 的筆記。
2、MobileOne 簡述
MobileOne 的核心模塊基於 MobileNetV1 設計,同時吸收了重參數化思想,得到上圖所示的結構。注:這裡的重參數化機制還有一個超參數 k,用於控制重參數化分支的數量(實驗表明:對於小模型來說,該變種收益更大)。在下面的實現中,k 對應參數 num_conv_branches。

從上圖可以看出,這個結構其實就是 DBB 與 RepVGG 的結合;分支數可以隨意擴寬,重參數化時直接把各分支的 weight 與 bias 合併即可。
3、MobileOne 的實現
以下是 MobileOne 的 PyTorch 實現:
from typing import Optional, List, Tuple
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
class MobileOneBlock(nn.Module):
    """MobileOne building block with train-time over-parameterization.

    In training mode the block sums the outputs of several parallel
    branches and applies SE (identity here) followed by ReLU:

    * ``num_conv_branches`` parallel ``kernel_size`` x ``kernel_size``
      conv + BatchNorm branches,
    * one 1x1 conv + BatchNorm "scale" branch (only if ``kernel_size > 1``),
    * a BatchNorm-only "skip" branch (only if ``in_channels ==
      out_channels`` and ``stride == 1``).

    :meth:`reparameterize` folds all branches into a single biased
    convolution (``reparam_conv``), so inference runs one conv + ReLU.
    Constructing the block with ``inference_mode=True`` builds that single
    conv directly.

    Note:
        ``use_se`` is accepted for API compatibility, but squeeze-and-
        excitation is not implemented in this version: a warning is
        emitted and ``nn.Identity`` is used instead.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 inference_mode: bool = False,
                 use_se: bool = False,
                 num_conv_branches: int = 3) -> None:
        """Build the block.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Spatial size of the main conv branches (odd sizes
                are assumed by the branch fusion, which centres the 1x1 and
                identity kernels at ``kernel_size // 2``).
            stride: Convolution stride.
            padding: Padding of the main ``kernel_size`` convs.
            dilation: Dilation of the main ``kernel_size`` convs.
            groups: Convolution groups (e.g. ``in_channels`` for depthwise).
            inference_mode: If True, build only the fused conv.
            use_se: Not implemented; see the class note.
            num_conv_branches: Number of parallel ``kernel_size`` branches
                (the hyper-parameter ``k`` in the MobileOne paper).
        """
        super(MobileOneBlock, self).__init__()
        self.inference_mode = inference_mode
        self.groups = groups
        self.stride = stride
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches

        # SE is not implemented here; make that explicit instead of
        # silently ignoring the flag.
        if use_se:
            import warnings
            warnings.warn('MobileOneBlock: use_se=True is not implemented; '
                          'squeeze-and-excitation is replaced by nn.Identity.',
                          stacklevel=2)
        self.se = nn.Identity()
        self.activation = nn.ReLU()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(in_channels=in_channels,
                                          out_channels=out_channels,
                                          kernel_size=kernel_size,
                                          stride=stride,
                                          padding=padding,
                                          dilation=dilation,
                                          groups=groups,
                                          bias=True)
        else:
            # Skip (identity) branch: BatchNorm only, valid when shapes match.
            self.rbr_skip = (nn.BatchNorm2d(num_features=in_channels)
                             if out_channels == in_channels and stride == 1
                             else None)
            # kxk conv branches. Dilation must be applied here too, otherwise
            # the fused conv (which copies these settings) loses it.
            self.rbr_conv = nn.ModuleList([
                self._conv_bn(kernel_size=kernel_size, padding=padding,
                              dilation=dilation)
                for _ in range(self.num_conv_branches)
            ])
            # 1x1 conv branch (scale branch); pointless when kernel_size == 1.
            self.rbr_scale = None
            if kernel_size > 1:
                self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the forward pass (multi-branch or fused, per mode)."""
        if self.inference_mode:
            return self.activation(self.se(self.reparam_conv(x)))

        # Identity (skip) branch output.
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)

        # Scale (1x1) branch output.
        scale_out = 0
        if self.rbr_scale is not None:
            scale_out = self.rbr_scale(x)

        # Sum in the kxk conv branches.
        out = scale_out + identity_out
        for ix in range(self.num_conv_branches):
            out += self.rbr_conv[ix](x)

        return self.activation(self.se(out))

    def reparameterize(self) -> None:
        """Fold all training-time branches into a single conv in place.

        After this call the block only holds ``reparam_conv`` (plus SE and
        activation) and ``inference_mode`` is True. Calling it again is a
        no-op. BatchNorm running statistics are used for the fusion, so the
        fused block matches the original in ``eval()`` mode.
        """
        if self.inference_mode:
            return
        kernel, bias = self._get_kernel_bias()
        main_conv = self.rbr_conv[0].conv
        self.reparam_conv = nn.Conv2d(in_channels=main_conv.in_channels,
                                      out_channels=main_conv.out_channels,
                                      kernel_size=main_conv.kernel_size,
                                      stride=main_conv.stride,
                                      padding=main_conv.padding,
                                      dilation=main_conv.dilation,
                                      groups=main_conv.groups,
                                      bias=True)
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # The fused parameters are no longer trained through the branches.
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_conv')
        self.__delattr__('rbr_scale')
        if hasattr(self, 'rbr_skip'):
            self.__delattr__('rbr_skip')

        self.inference_mode = True

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the summed (kernel, bias) of all branches with BN folded in."""
        # Scale branch: fuse BN, then zero-pad the 1x1 kernel to kxk so it
        # sits at the centre tap (assumes an odd kernel_size).
        kernel_scale = 0
        bias_scale = 0
        if self.rbr_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
            pad = self.kernel_size // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale,
                                                   [pad, pad, pad, pad])

        # Skip branch: BN applied to an identity kernel.
        kernel_identity = 0
        bias_identity = 0
        if self.rbr_skip is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

        # kxk conv branches.
        kernel_conv = 0
        bias_conv = 0
        for ix in range(self.num_conv_branches):
            _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
            kernel_conv += _kernel
            bias_conv += _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(self, branch: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fold a BatchNorm into the preceding conv (or an identity kernel).

        Args:
            branch: Either a ``conv``/``bn`` ``nn.Sequential`` built by
                :meth:`_conv_bn`, or a bare ``nn.BatchNorm2d`` (skip branch).

        Returns:
            ``(kernel, bias)`` equivalent to ``bn(conv(x))`` in eval mode.
        """
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            # Identity expressed as a grouped kxk conv kernel. Rebuild the
            # cache if the module was moved to another device/dtype since.
            id_tensor = getattr(self, 'id_tensor', None)
            if (id_tensor is None
                    or id_tensor.device != branch.weight.device
                    or id_tensor.dtype != branch.weight.dtype):
                input_dim = self.in_channels // self.groups
                kernel_value = torch.zeros((self.in_channels,
                                            input_dim,
                                            self.kernel_size,
                                            self.kernel_size),
                                           dtype=branch.weight.dtype,
                                           device=branch.weight.device)
                centre = self.kernel_size // 2
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, centre, centre] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self, kernel_size: int, padding: int,
                 dilation: int = 1) -> nn.Sequential:
        """Build a bias-free conv followed by BatchNorm.

        Args:
            kernel_size: Conv kernel size.
            padding: Conv padding.
            dilation: Conv dilation (default 1).

        Returns:
            ``nn.Sequential`` with submodules named ``conv`` and ``bn``.
        """
        mod_list = nn.Sequential()
        mod_list.add_module('conv', nn.Conv2d(in_channels=self.in_channels,
                                              out_channels=self.out_channels,
                                              kernel_size=kernel_size,
                                              stride=self.stride,
                                              padding=padding,
                                              dilation=dilation,
                                              groups=self.groups,
                                              bias=False))
        mod_list.add_module('bn', nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list
if __name__ == '__main__':
    # Demo: export the multi-branch block and its reparameterized form to
    # ONNX, and check that the fusion preserves the output.
    torch.manual_seed(0)
    model = MobileOneBlock(16, 16, 3, padding=1, num_conv_branches=1)
    # Use inference BatchNorm behaviour. In train mode a forward pass would
    # normalize with batch statistics and update the running stats, so its
    # output could not match the fused conv (which uses running stats).
    model.eval()
    x = torch.randn(1, 16, 9, 9)
    with torch.no_grad():
        y = model(x)
    torch.onnx.export(model,
                      (x,),
                      'mobileone_raw.onnx',
                      opset_version=12,
                      input_names=['input'],
                      output_names=['output'])

    model.reparameterize()
    with torch.no_grad():
        y_rep = model(x)
    torch.onnx.export(model,
                      (x,),
                      'mobileone_rep.onnx',
                      opset_version=12,
                      input_names=['input'],
                      output_names=['output'])

    max_diff = (y - y_rep).abs().max().item()
    print(f'max |raw - reparameterized| = {max_diff:.3e}')
    if not torch.allclose(y, y_rep, atol=1e-5):
        raise RuntimeError('reparameterization changed the block output')
話不多說,直接對比重參數化前後導出的 ONNX 模型,就問你香不香?!!

4、參考
[1].https://github.com/apple/ml-mobileone
[2].MobileOne: An Improved One millisecond Mobile Backbone
評論
圖片
表情
