Pytorch中Spatial-Shift-Operation的5種實(shí)現(xiàn)策略

極市導(dǎo)讀
?作者通過參考一些使用空間偏移操作來替代區(qū)域卷及運(yùn)算的論文以及提供的核心代碼,整合了現(xiàn)有的知識(shí)歸納總結(jié)了五種實(shí)現(xiàn)策略。?>>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿
原始文檔(可能會(huì)進(jìn)一步更新):https://www.yuque.com/lart/ugkv9f/nnor5p
前言
之前看了一些使用空間偏移操作來替代區(qū)域卷積運(yùn)算的論文:
粗看: https://www.yuque.com/lart/architecture/conv#uKY5N (CVPR 2018) [Grouped Shift] Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions: (ICCV 2019) 4-Connected Shift Residual Networks (NIPS 2018) [Active Shift] Constructing Fast Network through Deconstruction of Convolution (CVPR 2019) [Sparse Shift] All You Need Is a Few Shifts: Designing Efficient Convolutional Neural Networks for Image Classification 細(xì)看: Hire-MLP: Vision MLP via Hierarchical Rearrangement:https://www.yuque.com/lart/papers/lbhadn CycleMLP: A MLP-like Architecture for Dense Prediction:https://www.yuque.com/lart/papers/om3xb6 S2-MLP: Spatial-Shift MLP Architecture for Vision:https://www.yuque.com/lart/papers/dgdu2b S2-MLPv2: Improved Spatial-Shift MLP Architecture for Vision:https://www.yuque.com/lart/papers/dgdu2b
看完這些論文后, 通過參考他們提供的核心代碼(主要是后面那些MLP方法), 讓我對(duì)于實(shí)現(xiàn)空間偏移有了一些想法。
通過整合現(xiàn)有的知識(shí), 我歸納總結(jié)了五種實(shí)現(xiàn)策略。
由于我個(gè)人使用pytorch, 所以這里的展示也可能會(huì)用到pytorch自身提供的一些有用的函數(shù)。
問題描述
在提供實(shí)現(xiàn)之前,我們應(yīng)該先明確目的以便于后續(xù)的實(shí)現(xiàn)。這些現(xiàn)有的工作都可以簡(jiǎn)化為:
給定tensor , 這里遵循pytorch默認(rèn)的數(shù)據(jù)格式, 即 B, C, H, W .
通過變換操作, 將轉(zhuǎn)換為。
這里tensor , 為了提供合理的對(duì)比, 這里統(tǒng)一使用后面章節(jié)中基于"切片索引"策略的結(jié)果作為的值。
import?torch
xs?=?torch.meshgrid(torch.arange(5),?torch.arange(5))
x?=?torch.stack(xs,?dim=0)
x?=?x.unsqueeze(0).repeat(1,?4,?1,?1).float()
print(x)
'''
tensor([[[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]]]])
'''
方法1: 切片索引
這是最直接和簡(jiǎn)單的策略了,這也是S2-MLP系列中使用的策略。
我們將其作為其他所有策略的參考對(duì)象,后續(xù)的實(shí)現(xiàn)中同樣會(huì)得到這個(gè)結(jié)果。
direct_shift?=?torch.clone(x)
direct_shift[:,?0:2,?:,?1:]?=?torch.clone(direct_shift[:,?0:2,?:,?:4])
direct_shift[:,?2:4,?:,?:4]?=?torch.clone(direct_shift[:,?2:4,?:,?1:])
direct_shift[:,?4:6,?1:,?:]?=?torch.clone(direct_shift[:,?4:6,?:4,?:])
direct_shift[:,?6:8,?:4,?:]?=?torch.clone(direct_shift[:,?6:8,?1:,?:])
print(direct_shift)
'''
tensor([[[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]],
?????????[[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]]]])
'''
方法2: 特征圖偏移—— torch.roll
pytorch提供了一個(gè)直接對(duì)特征圖進(jìn)行偏移的函數(shù),即 torch.roll . 這一操作在最近的transformer論文和mlp中有一些工作已經(jīng)開始使用,例如SwinTransformer和AS-MLP。
這里展示下AS-MLP論文中提供的偽代碼:

其主要作用就是將特征圖沿著某個(gè)軸向進(jìn)行偏移,并支持同時(shí)沿著多個(gè)軸向偏移,從而構(gòu)造更多樣的偏移方向。
為了實(shí)現(xiàn)與前面相同的結(jié)果,我們需要首先對(duì)輸入進(jìn)行padding。因?yàn)橹苯忧衅饕袀€(gè)特點(diǎn)就是邊界值是會(huì)重復(fù)出現(xiàn)的,而若是直接roll操作,會(huì)導(dǎo)致所有的值整體移動(dòng)。
所以為了實(shí)現(xiàn)類似的效果,先對(duì)四周各padding一個(gè)網(wǎng)格的數(shù)據(jù),注意這里選擇使用重復(fù)模式(replicate)以實(shí)現(xiàn)最終的邊界重復(fù)值的效果。
import?torch.nn.functional?as?F
pad_x?=?F.pad(x,?pad=[1,?1,?1,?1],?mode="replicate")??#?這里需要借助padding來保留邊界的數(shù)據(jù)
接下來開始處理,沿著四個(gè)方向各偏移一個(gè)單位的長(zhǎng)度:
roll_shift?=?torch.cat(
????[
????????torch.roll(pad_x[:,?c?*?2?:?(c?+?1)?*?2,?...],?shifts=(shift_h,?shift_w),?dims=(2,?3))
????????for?c,?(shift_h,?shift_w)?in?enumerate([(0,?1),?(0,?-1),?(1,?0),?(-1,?0)])
????],
????dim=1,
)
'''
tensor([[[[0.,?0.,?0.,?0.,?0.,?0.,?0.],
??????????[0.,?0.,?0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.,?4.,?4.],
??????????[4.,?4.,?4.,?4.,?4.,?4.,?4.]],
?????????[[4.,?0.,?0.,?1.,?2.,?3.,?4.],
??????????[4.,?0.,?0.,?1.,?2.,?3.,?4.],
??????????[4.,?0.,?0.,?1.,?2.,?3.,?4.],
??????????[4.,?0.,?0.,?1.,?2.,?3.,?4.],
??????????[4.,?0.,?0.,?1.,?2.,?3.,?4.],
??????????[4.,?0.,?0.,?1.,?2.,?3.,?4.],
??????????[4.,?0.,?0.,?1.,?2.,?3.,?4.]],
?????????[[0.,?0.,?0.,?0.,?0.,?0.,?0.],
??????????[0.,?0.,?0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.,?4.,?4.],
??????????[4.,?4.,?4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?1.,?2.,?3.,?4.,?4.,?0.],
??????????[0.,?1.,?2.,?3.,?4.,?4.,?0.],
??????????[0.,?1.,?2.,?3.,?4.,?4.,?0.],
??????????[0.,?1.,?2.,?3.,?4.,?4.,?0.],
??????????[0.,?1.,?2.,?3.,?4.,?4.,?0.],
??????????[0.,?1.,?2.,?3.,?4.,?4.,?0.],
??????????[0.,?1.,?2.,?3.,?4.,?4.,?0.]],
?????????[[4.,?4.,?4.,?4.,?4.,?4.,?4.],
??????????[0.,?0.,?0.,?0.,?0.,?0.,?0.],
??????????[0.,?0.,?0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.]],
?????????[[0.,?0.,?0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.,?4.,?4.],
??????????[4.,?4.,?4.,?4.,?4.,?4.,?4.],
??????????[0.,?0.,?0.,?0.,?0.,?0.,?0.]],
?????????[[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.]]]])
'''
接下來只需要剪裁一下即可:
roll_shift?=?roll_shift[...,?1:6,?1:6]
print(roll_shift)
'''
tensor([[[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]],
?????????[[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]]]])
'''
方法3: 1x1 Deformable Convolution—— ops.deform_conv2d
在閱讀Cycle FC的過程中,了解到了Deformable Convolution在實(shí)現(xiàn)空間偏移操作上的妙用。
由于torchvision最新版已經(jīng)集成了這一操作,所以我們只需要導(dǎo)入函數(shù)即可:
from?torchvision.ops?import?deform_conv2d
為了使用它實(shí)現(xiàn)空間偏移,我在對(duì)Cycle FC的解讀中,對(duì)相關(guān)代碼添加了一些注釋信息:
要想理解這一函數(shù)的操作,需要首先理解后面使用的deform_conv2d_tv的具體用法。
具體可見:https://pytorch.org/vision/0.10/ops.html#torchvision.ops.deform_conv2d
這里對(duì)于offset參數(shù)的要求是:
offset (Tensor[batch_size, 2 _ offset_groups _ kernel_height * kernel_width, out_height, out_width])
offsets to be applied for each position in the convolution kernel.
也就是說, 對(duì)于樣本
s的輸出特征圖的通道c中的位置(x, y), 這個(gè)函數(shù)會(huì)從offset中取出, 形狀為kernel_height*kernel_width?的卷積核所對(duì)應(yīng)的偏移參數(shù), 其為offset[s, 0:2*offset_groups*kernel_height*kernel_width, x, y]。也就是這一系列參數(shù)都是對(duì)應(yīng)樣本s的單個(gè)位置(x, y)的。
針對(duì)不同的位置可以有不同的
offset, 也可以有相同的 (下面的實(shí)現(xiàn)就是后者)。對(duì)于這2*offset_groups*kernel_height*kernel_width個(gè)數(shù), 涉及到對(duì)于輸入特征通道的分組。將其分成offset_groups組, 每份單獨(dú)擁有一組對(duì)應(yīng)于卷積核中心位置的相對(duì)偏移量, 共2*kernel_height*kernel_width個(gè)數(shù)。
對(duì)于每個(gè)核參數(shù),使用兩個(gè)量來描述偏移, 即h方向和w方向相對(duì)中心位置的偏移,即對(duì)應(yīng)于后面代碼中的減去
kernel_height//2或者kernel_width//2。
需要注意的是,當(dāng)偏移位置位于
padding后的tensor的邊界之外,則是將網(wǎng)格使用0補(bǔ)齊。如果網(wǎng)格上有邊界值,則使用邊界值和用0補(bǔ)齊的網(wǎng)格頂點(diǎn)來計(jì)算雙線性插值的結(jié)果.
該策略需要我們?nèi)?gòu)造特定的相對(duì)偏移值offset來對(duì)1x1卷積核在不同通道的采樣位置進(jìn)行調(diào)整。
我們先構(gòu)造我們需要的offset 。這里之所以將 out_height & out_width 兩個(gè)維度設(shè)置為1, 是因?yàn)槲覀儗?duì)整個(gè)空間的偏移是一致的,所以只需要簡(jiǎn)單的重復(fù)數(shù)值即可。
offset?=?torch.empty(1,?2?*?8?*?1?*?1,?1,?1)
for?c,?(rel_offset_h,?rel_offset_w)?in?enumerate([(0,?-1),?(0,?-1),?(0,?1),?(0,?1),?(-1,?0),?(-1,?0),?(1,?0),?(1,?0)]):
????offset[0,?c?*?2?+?0,?0,?0]?=?rel_offset_h
????offset[0,?c?*?2?+?1,?0,?0]?=?rel_offset_w
offset?=?offset.repeat(1,?1,?7,?7).float()??#?針對(duì)空間偏移重復(fù)偏移量
在構(gòu)造offset的時(shí)候,我們要明確,其通道中的數(shù)據(jù)都是兩兩一組的,每一組包含著沿著H軸和W軸的相對(duì)偏移量 (這一相對(duì)偏移量應(yīng)該是以其作用的卷積權(quán)重位置為中心 —— 這一結(jié)論我并沒有驗(yàn)證,只是個(gè)人的推理,因?yàn)檫@樣可能在源碼中實(shí)現(xiàn)起來更加方便,可以直接作用權(quán)重對(duì)應(yīng)位置的坐標(biāo)。在不讀源碼的前提下理解函數(shù)的功能,那就需要自行構(gòu)造數(shù)據(jù)來驗(yàn)證性的理解了)。
為了更好的理解offset的作用的原理,我們可以想象對(duì)于采樣位置,使用相對(duì)偏移量作用后,采樣位置變成了。即原來作用于的權(quán)重,偏移后直接作用到了位置上。
對(duì)于我們的前面描述的沿著四個(gè)軸向各自一個(gè)單位偏移,可以通過對(duì)和分別賦予中的值即可實(shí)現(xiàn)。
由于這里僅需要體現(xiàn)通道特定的空間偏移作用,而并不需要Deformable Convolution的卷積功能, 我們需要將卷積核設(shè)置為單位矩陣,并轉(zhuǎn)換為分組卷積對(duì)應(yīng)的卷積核的形式:
weight?=?torch.eye(8).reshape(8,?8,?1,?1).float()
#?輸入8通道,輸出8通道,每個(gè)輸入通道只和一個(gè)對(duì)應(yīng)的輸出通道有映射權(quán)值1
接下來將權(quán)重和偏移送入導(dǎo)入的函數(shù)中。
由于該函數(shù)對(duì)于偏移超出邊界的位置是使用0補(bǔ)齊的網(wǎng)格計(jì)算的,所以為了實(shí)現(xiàn)前面邊界上的重復(fù)值的效果,這里同樣需要使用重復(fù)模式下的padding后的輸入。 并對(duì)結(jié)果進(jìn)行一下修剪:
deconv_shift?=?deform_conv2d(pad_x,?offset=offset,?weight=weight)
deconv_shift?=?deconv_shift[...,?1:6,?1:6]
print(deconv_shift)
'''
tensor([[[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]],
?????????[[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]]]])
'''
方法4: 3x3 Depthwise Convolution—— F.conv2d
在S2MLP中提到了空間偏移操作可以通過使用特殊構(gòu)造的3x3 Depthwise Convolution來實(shí)現(xiàn)。
由于基于3x3卷積操作,所以為了實(shí)現(xiàn)邊界值的重復(fù)效果仍然需要對(duì)輸入進(jìn)行重復(fù)padding。
首先構(gòu)造對(duì)應(yīng)四個(gè)方向的卷積核:
k1?=?torch.FloatTensor([[0,?0,?0],?[1,?0,?0],?[0,?0,?0]]).reshape(1,?1,?3,?3)
k2?=?torch.FloatTensor([[0,?0,?0],?[0,?0,?1],?[0,?0,?0]]).reshape(1,?1,?3,?3)
k3?=?torch.FloatTensor([[0,?1,?0],?[0,?0,?0],?[0,?0,?0]]).reshape(1,?1,?3,?3)
k4?=?torch.FloatTensor([[0,?0,?0],?[0,?0,?0],?[0,?1,?0]]).reshape(1,?1,?3,?3)
weight?=?torch.cat([k1,?k1,?k2,?k2,?k3,?k3,?k4,?k4],?dim=0)??#?每個(gè)輸出通道對(duì)應(yīng)一個(gè)輸入通道
接下來將卷積核和數(shù)據(jù)送入 F.conv2d 中計(jì)算即可,輸入在四邊各padding了1個(gè)單位,所以輸出形狀不變:
conv_shift?=?F.conv2d(pad_x,?weight=weight,?groups=8)
print(conv_shift)
'''
tensor([[[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]],
?????????[[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]]]])
'''
方法5: 網(wǎng)格采樣—— F.grid_sample
最后這里提到的基于 F.grid_sample ,該操作是pytorch提供的用于構(gòu)建STN的一個(gè)函數(shù),但是其在光流預(yù)測(cè)任務(wù)以及最近的一些分割任務(wù)中開始出現(xiàn):
AlignSeg: Feature-Aligned Segmentation Networks Semantic Flow for Fast and Accurate Scene Parsing
針對(duì)4Dtensor,其主要作用就是根據(jù)給定的網(wǎng)格采樣圖grid來對(duì)數(shù)據(jù)點(diǎn)進(jìn)行采樣以放置到輸出的位置中。
要注意的是,該函數(shù)限制了采樣圖grid的取值范圍是對(duì)輸入的尺寸歸一化后的結(jié)果,并且的最后一維度分別是在索引W軸、H軸。即對(duì)于輸入tensor的布局 B, C, H, W 的四個(gè)維度從后往前索引。實(shí)際上,這一規(guī)則在pytorch的其他函數(shù)的設(shè)計(jì)中廣泛遵循。例如pytorch中的pad函數(shù)的規(guī)則也是一樣的。
首先根據(jù)需求構(gòu)造基于輸入數(shù)據(jù)的原始坐標(biāo)數(shù)組 (左上角為,右上角為):
h_coord,?w_coord?=?torch.meshgrid(torch.arange(5),?torch.arange(5))
print(h_coord)
print(w_coord)
h_coord?=?h_coord.reshape(1,?5,?5,?1)
w_coord?=?w_coord.reshape(1,?5,?5,?1)
'''
tensor([[0,?0,?0,?0,?0],
????????[1,?1,?1,?1,?1],
????????[2,?2,?2,?2,?2],
????????[3,?3,?3,?3,?3],
????????[4,?4,?4,?4,?4]])
tensor([[0,?1,?2,?3,?4],
????????[0,?1,?2,?3,?4],
????????[0,?1,?2,?3,?4],
????????[0,?1,?2,?3,?4],
????????[0,?1,?2,?3,?4]])
'''
針對(duì)每一個(gè)輸出
????????????torch.cat(
????????????????[??#?請(qǐng)注意這里的堆疊順序,先放靠后的軸的坐標(biāo)
????????????????????2?*?torch.clamp(w_coord?+?w,?0,?4)?/?(5?-?1)?-?1,
????????????????????2?*?torch.clamp(h_coord?+?h,?0,?4)?/?(5?-?1)?-?1,
????????????????],
????????????????dim=-1,
????????????)
這里的參數(shù)表示基于原始坐標(biāo)系的偏移量。
由于這里直接使用clamp限制了采樣區(qū)間,靠近邊界的部分會(huì)重復(fù)使用,所以后續(xù)直接使用原始的輸入即可。
將新坐標(biāo)送入函數(shù)的時(shí)候,需要將其轉(zhuǎn)換為范圍內(nèi)的值,即針對(duì)輸入的形狀W和H進(jìn)行歸一化計(jì)算。
????????F.grid_sample(
????????????x,
????????????torch.cat(
????????????????[
????????????????????2?*?torch.clamp(w_coord?+?w,?0,?4)?/?(5?-?1)?-?1,
????????????????????2?*?torch.clamp(h_coord?+?h,?0,?4)?/?(5?-?1)?-?1,
????????????????],
????????????????dim=-1,
????????????),
????????????mode="bilinear",
????????????align_corners=True,
????????)
要注意,這里使用的是 align_corners=True ,關(guān)于pytorch中該參數(shù)的介紹可以查看https://www.yuque.com/lart/idh721/ugwn46。
True :
False :
所以可以看到,這里前者更符合我們的需求,因?yàn)檫@里提到的涉及雙線性插值的算法(例如前面的Deformable Convolution)的實(shí)現(xiàn)都是將像素放到網(wǎng)格頂點(diǎn)上的 (按照這一思路理解比較符合實(shí)驗(yàn)現(xiàn)象,我就姑且這樣描述)。
grid_sampled_shift?=?torch.cat(
????[
????????F.grid_sample(
????????????x,
????????????torch.cat(
????????????????[
????????????????????2?*?torch.clamp(w_coord?+?w,?0,?4)?/?(5?-?1)?-?1,
????????????????????2?*?torch.clamp(h_coord?+?h,?0,?4)?/?(5?-?1)?-?1,
????????????????],
????????????????dim=-1,
????????????),
????????????mode="bilinear",
????????????align_corners=True,
????????)
????????for?x,?(h,?w)?in?zip(x.chunk(4,?dim=1),?[(0,?-1),?(0,?1),?(-1,?0),?(1,?0)])
????],
????dim=1,
)
print(grid_sampled_shift)
'''
tensor([[[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.],
??????????[0.,?0.,?1.,?2.,?3.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.],
??????????[1.,?2.,?3.,?4.,?4.]],
?????????[[0.,?0.,?0.,?0.,?0.],
??????????[0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]],
?????????[[1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.],
??????????[4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.],
??????????[0.,?1.,?2.,?3.,?4.]]]])
'''
另外的一些思考
關(guān)于 F.grid_sample 的誤差問題
由于 F.grid_sample 涉及到歸一化操作,自然而然存在精度損失。所以實(shí)際上如果想要實(shí)現(xiàn)精確控制的話,不太建議使用這個(gè)方法。
如果位置恰好在但單元格角點(diǎn)上,倒是可以使用最近鄰插值的模式來獲得一個(gè)更加整齊的結(jié)果。
下面是一個(gè)例子:
h_coord,?w_coord?=?torch.meshgrid(torch.arange(7),?torch.arange(7))
h_coord?=?h_coord.reshape(1,?7,?7,?1)
w_coord?=?w_coord.reshape(1,?7,?7,?1)
grid?=?torch.cat(
????[
????????2?*?torch.clamp(w_coord,?0,?6)?/?(7?-?1)?-?1,
????????2?*?torch.clamp(h_coord,?0,?6)?/?(7?-?1)?-?1,
????],
????dim=-1,
)
print(grid)
print(pad_x[:,?:2])
print("mode=bilinear\n",?F.grid_sample(pad_x[:,?:2],?grid,?mode="bilinear",?align_corners=True))
print("mode=nearest\n",?F.grid_sample(pad_x[:,?:2],?grid,?mode="nearest",?align_corners=True))
'''
tensor([[[[-1.0000,?-1.0000],
??????????[-0.6667,?-1.0000],
??????????[-0.3333,?-1.0000],
??????????[?0.0000,?-1.0000],
??????????[?0.3333,?-1.0000],
??????????[?0.6667,?-1.0000],
??????????[?1.0000,?-1.0000]],
?????????[[-1.0000,?-0.6667],
??????????[-0.6667,?-0.6667],
??????????[-0.3333,?-0.6667],
??????????[?0.0000,?-0.6667],
??????????[?0.3333,?-0.6667],
??????????[?0.6667,?-0.6667],
??????????[?1.0000,?-0.6667]],
?????????[[-1.0000,?-0.3333],
??????????[-0.6667,?-0.3333],
??????????[-0.3333,?-0.3333],
??????????[?0.0000,?-0.3333],
??????????[?0.3333,?-0.3333],
??????????[?0.6667,?-0.3333],
??????????[?1.0000,?-0.3333]],
?????????[[-1.0000,??0.0000],
??????????[-0.6667,??0.0000],
??????????[-0.3333,??0.0000],
??????????[?0.0000,??0.0000],
??????????[?0.3333,??0.0000],
??????????[?0.6667,??0.0000],
??????????[?1.0000,??0.0000]],
?????????[[-1.0000,??0.3333],
??????????[-0.6667,??0.3333],
??????????[-0.3333,??0.3333],
??????????[?0.0000,??0.3333],
??????????[?0.3333,??0.3333],
??????????[?0.6667,??0.3333],
??????????[?1.0000,??0.3333]],
?????????[[-1.0000,??0.6667],
??????????[-0.6667,??0.6667],
??????????[-0.3333,??0.6667],
??????????[?0.0000,??0.6667],
??????????[?0.3333,??0.6667],
??????????[?0.6667,??0.6667],
??????????[?1.0000,??0.6667]],
?????????[[-1.0000,??1.0000],
??????????[-0.6667,??1.0000],
??????????[-0.3333,??1.0000],
??????????[?0.0000,??1.0000],
??????????[?0.3333,??1.0000],
??????????[?0.6667,??1.0000],
??????????[?1.0000,??1.0000]]]])
tensor([[[[0.,?0.,?0.,?0.,?0.,?0.,?0.],
??????????[0.,?0.,?0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.,?4.,?4.],
??????????[4.,?4.,?4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.]]]])
mode=bilinear
?tensor([[[[0.0000e+00,?0.0000e+00,?0.0000e+00,?0.0000e+00,?0.0000e+00,
???????????0.0000e+00,?0.0000e+00],
??????????[1.1921e-07,?1.1921e-07,?1.1921e-07,?1.1921e-07,?1.1921e-07,
???????????1.1921e-07,?1.1921e-07],
??????????[1.0000e+00,?1.0000e+00,?1.0000e+00,?1.0000e+00,?1.0000e+00,
???????????1.0000e+00,?1.0000e+00],
??????????[2.0000e+00,?2.0000e+00,?2.0000e+00,?2.0000e+00,?2.0000e+00,
???????????2.0000e+00,?2.0000e+00],
??????????[3.0000e+00,?3.0000e+00,?3.0000e+00,?3.0000e+00,?3.0000e+00,
???????????3.0000e+00,?3.0000e+00],
??????????[4.0000e+00,?4.0000e+00,?4.0000e+00,?4.0000e+00,?4.0000e+00,
???????????4.0000e+00,?4.0000e+00],
??????????[4.0000e+00,?4.0000e+00,?4.0000e+00,?4.0000e+00,?4.0000e+00,
???????????4.0000e+00,?4.0000e+00]],
?????????[[0.0000e+00,?1.1921e-07,?1.0000e+00,?2.0000e+00,?3.0000e+00,
???????????4.0000e+00,?4.0000e+00],
??????????[0.0000e+00,?1.1921e-07,?1.0000e+00,?2.0000e+00,?3.0000e+00,
???????????4.0000e+00,?4.0000e+00],
??????????[0.0000e+00,?1.1921e-07,?1.0000e+00,?2.0000e+00,?3.0000e+00,
???????????4.0000e+00,?4.0000e+00],
??????????[0.0000e+00,?1.1921e-07,?1.0000e+00,?2.0000e+00,?3.0000e+00,
???????????4.0000e+00,?4.0000e+00],
??????????[0.0000e+00,?1.1921e-07,?1.0000e+00,?2.0000e+00,?3.0000e+00,
???????????4.0000e+00,?4.0000e+00],
??????????[0.0000e+00,?1.1921e-07,?1.0000e+00,?2.0000e+00,?3.0000e+00,
???????????4.0000e+00,?4.0000e+00],
??????????[0.0000e+00,?1.1921e-07,?1.0000e+00,?2.0000e+00,?3.0000e+00,
???????????4.0000e+00,?4.0000e+00]]]])
mode=nearest
?tensor([[[[0.,?0.,?0.,?0.,?0.,?0.,?0.],
??????????[0.,?0.,?0.,?0.,?0.,?0.,?0.],
??????????[1.,?1.,?1.,?1.,?1.,?1.,?1.],
??????????[2.,?2.,?2.,?2.,?2.,?2.,?2.],
??????????[3.,?3.,?3.,?3.,?3.,?3.,?3.],
??????????[4.,?4.,?4.,?4.,?4.,?4.,?4.],
??????????[4.,?4.,?4.,?4.,?4.,?4.,?4.]],
?????????[[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.],
??????????[0.,?0.,?1.,?2.,?3.,?4.,?4.]]]])
'''
F.grid_sample 與Deformable Convolution的關(guān)系
雖然二者都實(shí)現(xiàn)了對(duì)于輸入與輸出位置映射關(guān)系的調(diào)整,但是二者調(diào)整的方式有著明顯的差別。
參考坐標(biāo)系不同 前者的坐標(biāo)系是基于整體輸入的一個(gè)歸一化坐標(biāo)系, 原點(diǎn)為輸入的HW平面的中心位置, H軸和W軸分別以向下和向右為正向. 而在坐標(biāo)系WOH中, 輸入數(shù)據(jù)的左上角為, 右上角為. 后者的坐標(biāo)系是相對(duì)于權(quán)重初始作用位置的相對(duì)坐標(biāo)系. 但是實(shí)際上, 這里其實(shí)理解為沿著H軸和W軸的_相對(duì)偏移量_更為合適. 例如, 將權(quán)重作用位置向左偏移一個(gè)單位, 實(shí)際上讓其對(duì)應(yīng)的偏移參數(shù)組取值為即可, 即將作用位置相對(duì)于原始作用位置的坐標(biāo)加上個(gè). 作用效果不同 前者直接對(duì)整體輸入進(jìn)行坐標(biāo)調(diào)整, 對(duì)于輸入的所有通道具有相同的調(diào)整效果. 后者由于構(gòu)建于卷積操作之上, 所以可以更加方便的處理不同通道( offset_groups)、不同的實(shí)際上可能有重疊的局部區(qū)域(kernel_height * kernel_width). 所以實(shí)際功能更加靈活和可調(diào)整.
Shift操作的第二春
雖然在之前的工作中已經(jīng)探索了多種空間shift操作的形式,但是卻并沒有引起太多的關(guān)注。
(CVPR 2018) [Grouped Shift] Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions (ICCV 2019) 4-Connected Shift Residual Networks (NIPS 2018) [Active Shift] Constructing Fast Network through Deconstruction of Convolution (CVPR 2019) [Sparse Shift] All You Need Is a Few Shifts: Designing Efficient Convolutional Neural Networks for Image Classification
這些工作大多專注于輕量化網(wǎng)絡(luò)的設(shè)計(jì),而現(xiàn)在的這些基于shift的方法,則結(jié)合了MLP這一快船,好像又激起了一些新的水花。
當(dāng)前的這些方法,往往會(huì)采用更有效的訓(xùn)練設(shè)定,這些模型之外的策略在一定程度上也極大的提升了模型的表現(xiàn)。這其實(shí)也會(huì)讓人疑惑,如果直接遷移之前的那些shift操作到這里的MLP框架中,或許性能也不會(huì)差吧?
這一想法其實(shí)也適用于傳統(tǒng)的CNN方法,之前的那些結(jié)構(gòu)如果使用相同的訓(xùn)練策略,相比現(xiàn)在,到底能差多少?這估計(jì)只能那些有卡有時(shí)間有耐心的大佬們能夠一探究竟了。
實(shí)際上綜合來看,現(xiàn)有的這些基于空間偏移的MLP的方法,更可以看作是 (NIPS 2018) [Active Shift] Constructing Fast Network through Deconstruction of Convolution(https://www.yuque.com/lart/architecture/conv#tjP7f) 這篇工作的特化版本。

也就是將原本這篇工作中的自適應(yīng)學(xué)習(xí)的偏移參數(shù)改成了固定的偏移參數(shù)。
如果覺得有用,就請(qǐng)分享到朋友圈吧!
公眾號(hào)后臺(tái)回復(fù)“transformer”獲取最新Transformer綜述論文下載~

#?極市平臺(tái)簽約作者#
Lart
知乎:人民藝術(shù)家
CSDN:有為少年
大連理工大學(xué)在讀博士
研究領(lǐng)域:主要方向?yàn)閳D像分割,但多從事于二值圖像分割的研究。也會(huì)關(guān)注其他領(lǐng)域,例如分類和檢測(cè)等方向的發(fā)展。
作品精選


