Pyramid ViT | Huawei improves the Transformer with a pyramid structure for clear accuracy gains (line-by-line PyTorch walkthrough)


Transformers have made great progress on computer vision tasks. The Transformer-in-Transformer (TNT) architecture uses an inner Transformer and an outer Transformer to extract local and global representations. In this work, a new TNT baseline is proposed by introducing two advanced designs:
Pyramid Architecture
Convolutional Stem
The new "PyramidTNT" significantly improves the original TNT by building hierarchical representations, and achieves better performance than previous state-of-the-art Vision Transformers such as Swin Transformer.
1 Introduction
Vision Transformers offer a new approach to computer vision. Starting from ViT, a series of works have improved the Vision Transformer architecture:
PVT introduced a pyramid network architecture for Vision Transformers
T2T-ViT recursively aggregates neighboring tokens into one token to extract local structure and reduce the number of tokens
TNT uses an inner Transformer and an outer Transformer to model word-level and sentence-level visual representations
Swin Transformer proposes a hierarchical Transformer whose representation is computed with shifted windows
With recent progress, Vision Transformers can now outperform convolutional neural networks (CNNs). This work builds an improved Vision Transformer baseline on the TNT framework, introducing two main architectural modifications:
Pyramid Architecture: gradually reduce the resolution to extract multi-scale representations
Convolutional Stem: a convolutional patchify stem that stabilizes training
The authors also use several other tricks to further improve efficiency. The new Transformer is named PyramidTNT.
Experiments on image classification and object detection demonstrate the superiority of PyramidTNT. Specifically, PyramidTNT-S achieves 82.0% ImageNet classification accuracy with only 3.3B FLOPs, clearly better than the original TNT-S and Swin-T.
For COCO detection, PyramidTNT-S reaches 42.0 mAP with less computation than existing Transformer- and MLP-based detection models.
2 Method
2.1 Convolutional Stem
Given an input image, the TNT model first splits the image into patches and further treats each patch as a sequence of sub-patches. A linear layer then projects the sub-patches into visual word vectors (also called tokens). These visual words are concatenated and transformed into a visual sentence vector.
Xiao et al. found that using a stack of convolutions as the stem in ViT improves both optimization stability and performance. Building on this, a pyramid convolutional stem is constructed here: a stack of 3×3 convolutions produces the visual word vectors of dimension C, and in the same way yields the visual sentence vectors of dimension D. Word-level and sentence-level position encodings are added to the visual words and sentences respectively, as in the original TNT.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def to_2tuple(x):
    # minimal helper; timm provides an equivalent
    return x if isinstance(x, tuple) else (x, x)

class Stem(nn.Module):
    """
    Image to Visual Word Embedding
    """
    def __init__(self, img_size=224, in_chans=3, outer_dim=768, inner_dim=24):
        super().__init__()
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.inner_dim = inner_dim
        self.num_patches = (img_size[0] // 8) * (img_size[1] // 8)
        self.num_words = 16

        self.common_conv = nn.Sequential(
            nn.Conv2d(in_chans, inner_dim*2, 3, stride=2, padding=1),
            nn.BatchNorm2d(inner_dim*2),
            nn.ReLU(inplace=True),
        )
        # convs producing word-level tokens for the inner Transformer
        self.inner_convs = nn.Sequential(
            nn.Conv2d(inner_dim*2, inner_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(inner_dim),
            nn.ReLU(inplace=False),
        )
        # convs producing sentence-level tokens for the outer Transformer
        self.outer_convs = nn.Sequential(
            nn.Conv2d(inner_dim*2, inner_dim*4, 3, stride=2, padding=1),
            nn.BatchNorm2d(inner_dim*4),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_dim*4, inner_dim*8, 3, stride=2, padding=1),
            nn.BatchNorm2d(inner_dim*8),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_dim*8, outer_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(outer_dim),
            nn.ReLU(inplace=False),
        )

        self.unfold = nn.Unfold(kernel_size=4, padding=0, stride=4)

    def forward(self, x):
        B, C, H, W = x.shape
        H_out, W_out = H // 8, W // 8
        H_in, W_in = 4, 4
        x = self.common_conv(x)
        # inner_tokens: word-level representations
        inner_tokens = self.inner_convs(x)  # B, C, H/2, W/2
        inner_tokens = self.unfold(inner_tokens).transpose(1, 2)  # B, N, C*16
        inner_tokens = inner_tokens.reshape(B * H_out * W_out, self.inner_dim, H_in * W_in).transpose(1, 2)  # B*N, 16, C
        # outer_tokens: sentence-level representations
        outer_tokens = self.outer_convs(x)  # B, D, H_out, W_out
        outer_tokens = outer_tokens.permute(0, 2, 3, 1).reshape(B, H_out * W_out, -1)
        return inner_tokens, outer_tokens, (H_out, W_out), (H_in, W_in)
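As a quick sanity check on the Stem's token bookkeeping (a standalone sketch with illustrative sizes, not part of the model): nn.Unfold with kernel_size=stride=4 cuts the word-level feature map into non-overlapping 4×4 windows, and the two reshapes turn it into a sequence of 16 visual words per patch.

```python
import torch
import torch.nn as nn

B, C = 2, 24                        # inner_dim = 24
feat = torch.randn(B, C, 16, 16)    # word-level map of a 32x32 image (after the stride-2 stem conv)
unfold = nn.Unfold(kernel_size=4, padding=0, stride=4)
t = unfold(feat).transpose(1, 2)    # B, N, C*16 with N = 4*4 = 16 patches
t = t.reshape(B * 16, C, 16).transpose(1, 2)  # B*N, 16, C: 16 visual words per patch
print(tuple(t.shape))               # (32, 16, 24)
```

For a 224×224 input the same arithmetic gives N = 28×28 = 784 patches, matching `num_patches = (224 // 8) ** 2`.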
2.2 Pyramid Architecture
Following ViT, the original TNT network keeps the same number of tokens in every block: the number of visual words and visual sentences stays constant from bottom to top.
Inspired by PVT, this work builds four stages with different numbers of tokens for TNT, as shown in Figure 1(b). Across the four stages, the spatial shapes of the visual words are H/2×W/2, H/4×W/4, H/8×W/8, and H/16×W/16, and the spatial shapes of the visual sentences are H/8×W/8, H/16×W/16, H/32×W/32, and H/64×W/64. Downsampling is implemented with stride-2 convolutions. Each stage consists of several TNT blocks that operate on word-level and sentence-level features. Finally, global average pooling fuses the output visual sentences into a single vector as the image representation.
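To make the sentence-grid shapes concrete, here is a small standalone calculation assuming a 224×224 input: the stem emits an H/8×W/8 grid, and each downsampling halves it with ceil rounding (as the aggregation code below does), so the nominal H/64 of the last stage becomes 4 rather than 3.5.

```python
import math

H = W = 224
h, w = H // 8, W // 8              # stem outputs an H/8 x W/8 sentence grid
shapes = [(h, w)]
for _ in range(3):                 # each downsampling halves the grid with ceil rounding
    h, w = math.ceil(h / 2), math.ceil(w / 2)
    shapes.append((h, w))
print(shapes)  # [(28, 28), (14, 14), (7, 7), (4, 4)]
```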

class SentenceAggregation(nn.Module):
    """
    Sentence Aggregation
    """
    def __init__(self, dim_in, dim_out, stride=2, act_layer=nn.GELU):
        super().__init__()
        self.stride = stride
        self.norm = nn.LayerNorm(dim_in)
        self.conv = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=2*stride-1, padding=stride-1, stride=stride),
        )

    def forward(self, x, H, W):
        B, N, C = x.shape  # B, N, C
        x = self.norm(x)
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.conv(x)
        H, W = math.ceil(H / self.stride), math.ceil(W / self.stride)
        x = x.reshape(B, -1, H * W).transpose(1, 2)
        return x, H, W
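The kernel_size=2*stride-1, padding=stride-1 pattern makes the aggregation conv shrink the grid to exactly ceil(H/stride), which is why the forward pass can recompute H and W with math.ceil. A standalone check for stride=2 (i.e. a 3×3/s2/p1 conv):

```python
import math
import torch
import torch.nn as nn

# for stride=2: kernel = 2*stride-1 = 3, padding = stride-1 = 1
conv = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2)
for H in (14, 7):
    y = conv(torch.randn(1, 64, H, H))
    # output side = floor((H + 2*1 - 3) / 2) + 1 = ceil(H / 2)
    assert y.shape[-1] == math.ceil(H / 2)
print(tuple(y.shape))  # (1, 128, 4, 4) for H = 7
```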
class WordAggregation(nn.Module):
    """
    Word Aggregation
    """
    def __init__(self, dim_in, dim_out, stride=2, act_layer=nn.GELU):
        super().__init__()
        self.stride = stride
        self.dim_out = dim_out
        self.norm = nn.LayerNorm(dim_in)
        self.conv = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=2*stride-1, padding=stride-1, stride=stride),
        )

    def forward(self, x, H_out, W_out, H_in, W_in):
        B_N, M, C = x.shape  # B*N, M, C
        x = self.norm(x)
        x = x.reshape(-1, H_out, W_out, H_in, W_in, C)

        # padding to fit (1333, 800) inputs in detection
        pad_input = (H_out % 2 == 1) or (W_out % 2 == 1)
        if pad_input:
            x = F.pad(x.permute(0, 3, 4, 5, 1, 2), (0, W_out % 2, 0, H_out % 2))
            x = x.permute(0, 4, 5, 1, 2, 3)
        # patch merge: gather each 2x2 patch neighborhood
        x1 = x[:, 0::2, 0::2, :, :, :]  # B, H/2, W/2, H_in, W_in, C
        x2 = x[:, 1::2, 0::2, :, :, :]
        x3 = x[:, 0::2, 1::2, :, :, :]
        x4 = x[:, 1::2, 1::2, :, :, :]
        x = torch.cat([torch.cat([x1, x2], 3), torch.cat([x3, x4], 3)], 4)  # B, H/2, W/2, 2*H_in, 2*W_in, C
        x = x.reshape(-1, 2*H_in, 2*W_in, C).permute(0, 3, 1, 2)  # B_N/4, C, 2*H_in, 2*W_in
        x = self.conv(x)  # B_N/4, C, H_in, W_in
        x = x.reshape(-1, self.dim_out, M).transpose(1, 2)
        return x
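The four strided slices in WordAggregation implement a patch merging: they pick the four members of every 2×2 patch neighborhood and stitch their word grids together. A standalone check (with illustrative sizes) that merging quarters the patch count, doubles H_in and W_in, and loses no elements:

```python
import torch

B, H_out, W_out, H_in, W_in, C = 1, 4, 4, 4, 4, 8
x = torch.randn(B, H_out, W_out, H_in, W_in, C)
# the four members of every 2x2 patch neighborhood
x1 = x[:, 0::2, 0::2]
x2 = x[:, 1::2, 0::2]
x3 = x[:, 0::2, 1::2]
x4 = x[:, 1::2, 1::2]
m = torch.cat([torch.cat([x1, x2], 3), torch.cat([x3, x4], 3)], 4)
print(tuple(m.shape))              # (1, 2, 2, 8, 8, 8)
assert m.numel() == x.numel()      # nothing dropped, just regrouped
```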
class Stage(nn.Module):
    """
    PyramidTNT stage
    """
    def __init__(self, num_blocks, outer_dim, inner_dim, outer_head, inner_head, num_patches, num_words, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, se=0, sr_ratio=1):
        super().__init__()
        blocks = []
        drop_path = drop_path if isinstance(drop_path, list) else [drop_path] * num_blocks

        for j in range(num_blocks):
            # only the leading block(s) of a stage run the inner Transformer;
            # _inner_dim = -1 disables it in the remaining blocks
            if j == 0:
                _inner_dim = inner_dim
            elif j == 1 and num_blocks > 6:
                _inner_dim = inner_dim
            else:
                _inner_dim = -1
            blocks.append(Block(  # Block is the TNT block, defined elsewhere in the repo
                outer_dim, _inner_dim, outer_head=outer_head, inner_head=inner_head,
                num_words=num_words, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop,
                attn_drop=attn_drop, drop_path=drop_path[j], act_layer=act_layer, norm_layer=norm_layer,
                se=se, sr_ratio=sr_ratio))
        self.blocks = nn.ModuleList(blocks)
        self.relative_pos = nn.Parameter(torch.randn(1, outer_head, num_patches, num_patches // sr_ratio // sr_ratio))

    def forward(self, inner_tokens, outer_tokens, H_out, W_out, H_in, W_in):
        for blk in self.blocks:
            inner_tokens, outer_tokens = blk(inner_tokens, outer_tokens, H_out, W_out, H_in, W_in, self.relative_pos)
        return inner_tokens, outer_tokens


class PyramidTNT(nn.Module):
    """
    PyramidTNT
    """
    def __init__(self, configs=None, img_size=224, in_chans=3, num_classes=1000, mlp_ratio=4., qkv_bias=False,
                 qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, se=0):
        super().__init__()
        self.num_classes = num_classes
        depths = configs['depths']
        outer_dims = configs['outer_dims']
        inner_dims = configs['inner_dims']
        outer_heads = configs['outer_heads']
        inner_heads = configs['inner_heads']
        sr_ratios = [4, 2, 1, 1]
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        self.num_features = outer_dims[-1]  # num_features for consistency with other models
        self.patch_embed = Stem(
            img_size=img_size, in_chans=in_chans, outer_dim=outer_dims[0], inner_dim=inner_dims[0])
        num_patches = self.patch_embed.num_patches
        num_words = self.patch_embed.num_words

        self.outer_pos = nn.Parameter(torch.zeros(1, num_patches, outer_dims[0]))
        self.inner_pos = nn.Parameter(torch.zeros(1, num_words, inner_dims[0]))
        self.pos_drop = nn.Dropout(p=drop_rate)
        depth = 0
        self.word_merges = nn.ModuleList([])
        self.sentence_merges = nn.ModuleList([])
        self.stages = nn.ModuleList([])
        # build the 4 stages of PyramidTNT
        for i in range(4):
            if i > 0:
                self.word_merges.append(WordAggregation(inner_dims[i-1], inner_dims[i], stride=2))
                self.sentence_merges.append(SentenceAggregation(outer_dims[i-1], outer_dims[i], stride=2))
            self.stages.append(Stage(depths[i], outer_dim=outer_dims[i], inner_dim=inner_dims[i],
                        outer_head=outer_heads[i], inner_head=inner_heads[i],
                        num_patches=num_patches // (2 ** i) // (2 ** i), num_words=num_words, mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                        drop_path=dpr[depth:depth+depths[i]], norm_layer=norm_layer, se=se, sr_ratio=sr_ratios[i])
            )
            depth += depths[i]

        self.norm = norm_layer(outer_dims[-1])
        # Classifier head
        self.head = nn.Linear(outer_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        inner_tokens, outer_tokens, (H_out, W_out), (H_in, W_in) = self.patch_embed(x)
        inner_tokens = inner_tokens + self.inner_pos  # B*N, 16, C
        outer_tokens = outer_tokens + self.pos_drop(self.outer_pos)  # B, N, D

        for i in range(4):
            if i > 0:
                inner_tokens = self.word_merges[i-1](inner_tokens, H_out, W_out, H_in, W_in)
                outer_tokens, H_out, W_out = self.sentence_merges[i-1](outer_tokens, H_out, W_out)
            inner_tokens, outer_tokens = self.stages[i](inner_tokens, outer_tokens, H_out, W_out, H_in, W_in)

        outer_tokens = self.norm(outer_tokens)
        return outer_tokens.mean(dim=1)

    def forward(self, x):
        # feature extraction; usable as a backbone for downstream tasks
        x = self.forward_features(x)
        # classification head
        x = self.head(x)
        return x
2.3 Other Tricks
Besides the architectural modifications, several advanced Vision Transformer techniques are adopted:
Relative position encodings are added to the self-attention modules to better represent the relative positions between tokens.
The first two stages use Linear Spatial Reduction Attention (LSRA) to reduce the computational complexity of self-attention over long sequences.
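The LSRA code is not part of this excerpt, so the following is a hedged standalone sketch of the idea (as popularized by PVTv2): queries keep the full sequence length N, while keys and values are average-pooled to a small fixed grid, dropping attention cost from O(N²) to O(N·p²). All names and the pool size here are illustrative, not the paper's exact implementation.

```python
import torch
import torch.nn as nn

class LSRASketch(nn.Module):
    """Linear spatial reduction attention: full-length queries, pooled keys/values."""
    def __init__(self, dim, num_heads=4, pool_size=7):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.pool = nn.AdaptiveAvgPool2d(pool_size)  # K/V reduced to pool_size^2 tokens
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)   # B, h, N, C/h
        kv = self.pool(x.transpose(1, 2).reshape(B, C, H, W))             # B, C, p, p
        kv = kv.flatten(2).transpose(1, 2)                                # B, p*p, C
        k, v = self.kv(kv).chunk(2, dim=-1)
        k = k.reshape(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)
        v = v.reshape(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)     # B, h, N, p*p
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)

x = torch.randn(2, 28 * 28, 64)       # a 28x28 sentence grid, dim 64
out = LSRASketch(64)(x, 28, 28)
print(tuple(out.shape))               # (2, 784, 64)
```

With a 7×7 pool, a 784-token stage attends over only 49 key/value positions per query, which is what makes the attention cost linear in N.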
3 Experiments
3.1 Classification
Table 3 shows the ImageNet-1K classification results. Compared with the original TNT, PyramidTNT achieves better image classification accuracy; for example, PyramidTNT-S obtains 0.5% higher Top-1 accuracy than TNT-S with 1.9B fewer FLOPs. PyramidTNT is also compared with other representative CNN-, MLP-, and Transformer-based models, and the results show that it is a state-of-the-art Vision Transformer.

3.2 Object Detection
Table 4 reports object detection and instance segmentation results under the "1x" training schedule. PyramidTNT-S significantly outperforms other backbones on both one-stage and two-stage detectors at similar computational cost. For example, RetinaNet with PyramidTNT-S reaches 42.0 AP and 57.7 AP-L, which is 0.5 AP and 2.2 AP-L higher than the Swin-Transformer-based model.
These results suggest that the PyramidTNT architecture better captures global information for large objects, while the pyramid's simple upsampling strategy and smaller spatial shapes keep AP-S on small objects from improving by a similar margin.

3.3 Instance Segmentation
PyramidTNT-S also obtains better AP-b and AP-m with Mask R-CNN and Cascade Mask R-CNN, showing stronger feature representation ability. For example, Mask R-CNN with PyramidTNT-S exceeds Hire-MLP-S by 0.9 AP-b.

4 Reference
[1]. PyramidTNT: Improved Transformer-in-Transformer Baselines with Pyramid Architecture