Swin Transformer的繼任者(上)
點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師”
設(shè)為星標(biāo),干貨直達(dá)!
近期,隨著PVT和Swin Transformer的成功,讓我們看到了將ViT應(yīng)用在dense prediction的backbone的巨大前景。PVT的核心是金字塔結(jié)構(gòu),同時(shí)通過對attention的keys和values進(jìn)行downsample來進(jìn)一步減少計(jì)算量,但是其計(jì)算復(fù)雜度依然和圖像大小()的平成正比。而Swin Transformer在金字塔結(jié)構(gòu)基礎(chǔ)上提出了window attention,這其實(shí)本質(zhì)上是一種local attention,并通過shifted window來建立cross-window的關(guān)系,其計(jì)算復(fù)雜度和圖像大小()成正比。基于local attention的模型計(jì)算復(fù)雜低,但是也喪失了global attention的全局感受野建模能力。近期,在Swin Transformer之后也有一些基于local attention的工作,它們從不同的方面來提升模型的全局建模能力。
Twins
美團(tuán)提出的Twins思路比較簡單,那就是將local attention和global attention結(jié)合在一起。Twins主體也采用金字塔結(jié)構(gòu),但是每個(gè)stage中交替地采用LSA(Locally-grouped self-attention)和GSA(Global sub-sampled attention),這里的LSA其實(shí)就是Swin Transformer中的window attention,而GSA就是PVT中采用的對keys和values進(jìn)行subsapmle的MSA。LSA用來提取局部特征,而GSA用來實(shí)現(xiàn)全局感受野:

此外,Twins還引入了美團(tuán)之前論文CPVT提出的PEG(position encoding generator)來進(jìn)行位置編碼,具體是在每個(gè)stage的第一個(gè)transfomer encoder后插入一個(gè)PEG(具體實(shí)現(xiàn)上是一個(gè)3x3的depth-wise conv)。如果將PVT中的位置編碼用PEG替換(稱為Twins-PCPVT),那么模型效果也有一個(gè)明顯的提升。

同樣地,用了PEG后,可以將window attention中的相對位置編碼也去掉了(相比Swin Transformer),最終的模型稱為Twins-SVT。在224x224輸入的ImageNet數(shù)據(jù)集上,可以看到Twins-SVT分類效果超過了Swin,而且模型參數(shù)和計(jì)算量均更低。

在COCO數(shù)據(jù)集上,基于Mask R-CNN模型,Twins-SVT也比Swin模型效果要好,而且FLOPs更低,不過這是在800x600圖片大小下測試的。畢竟GSA計(jì)算復(fù)雜度還是和圖像大小的平方成正比,當(dāng)圖像輸入原來越大時(shí),Twins-SVT也會像PVT那樣計(jì)算量增加迅速,但是Swin模型是線性增長。

MSG-Transformer
華為提出的MSG-Transformer主要思路是為每個(gè)window增加一個(gè)信使token(messenger token, MSG),這個(gè)不同的windows通過MSG token來建立聯(lián)系,具體的操作是對MSG token進(jìn)行shuffle。下圖中圖像共分為個(gè)windows(綠色線條),而每個(gè)windows組成一個(gè)shuffle region;每個(gè)Window都包含一個(gè)MSG token,經(jīng)過window attention之后,同一個(gè)shuffle region的MSG token將先進(jìn)行shuffle,最后才送入MLP中。

對于一個(gè)shuffle region,這里記其大小為,其MSG tokens組合在一起記為,這里是特征維度大小。MSG token的shuffle可以通過reshape->transpose->reshape來實(shí)現(xiàn):
其實(shí)就是對MSG tokens的特征進(jìn)行shuffle,這樣shuffle后每個(gè)window的MSG token將包含其它windows的部分MSG token特征,從而完成不同windows之間的消息傳遞:

而MSG Transformer主體也采用金字塔結(jié)構(gòu),不同的stage的取值不同,對于分類任務(wù),各個(gè)stage的分別為4,4,2,1。在實(shí)現(xiàn)上,我們可以將同一個(gè)shuffle region區(qū)域放在維度1,而總的shuffle regions和Batch放在第一個(gè)維度,這樣就非常實(shí)現(xiàn)MSG tokens的shuffle:
def window_partition(x, window_size, shuf_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
shuf_size (int): shuffle region size
Returns:
windows: (B*num_region, shuf_size**2, window_size**2, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size // shuf_size, shuf_size, window_size,
W // window_size // shuf_size, shuf_size, window_size, C)
windows = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(-1, shuf_size**2, window_size**2, C)
return windows
def shuffel_msg(x):
# (B, G, win**2+1, C)
B, G, N, C = x.shape
if G == 1:
return x
msges = x[:, :, 0] # (B, G, C)
assert C % G == 0
msges = msges.view(-1, G, G, C//G).transpose(1, 2).reshape(B, G, 1, C)
x = torch.cat((msges, x[:, :, 1:]), dim=2)
return x
MSG Transformer的window attention和Swin Transformer一樣也采用相對位置編碼,但是多了一個(gè)MSG token,所以相對位置編碼多了兩個(gè)參數(shù)(其它patch tokens相對MSG token,MSG token相對其它patch tokens)。另外在每個(gè)stage開始的token merging操作,對MSG token也采取類似的處理:2x2個(gè)windows的MSG token進(jìn)行concat,并進(jìn)行線性變換。
MSG Transformer引入的MSG token對計(jì)算量和模型參數(shù)都影響不大,所以其和Swin Transformer一樣其計(jì)算復(fù)雜度線性于圖像大小。在ImageNet上,其模型效果和Swin接近,但其在CPU上速度較快:

在COCO數(shù)據(jù)集上,基于Mask R-CNN模型,也可以和Swin模型取得類似的效果:

參考
Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer Twins: Revisiting the Design of Spatial Attention in Vision Transformers Glance-and-Gaze Vision Transformer MSG-Transformer: Exchanging Local Spatial Information by Manipulating Messenger Tokens Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions Swin Transformer: Hierarchical Vision Transformer using Shifted Windows Demystifying Local Vision Transformer: Sparse Connectivity, Weight Sharing, and Dynamic Weight
推薦閱讀
谷歌AI用30億數(shù)據(jù)訓(xùn)練了一個(gè)20億參數(shù)Vision Transformer模型,在ImageNet上達(dá)到新的SOTA!
"未來"的經(jīng)典之作ViT:transformer is all you need!
PVT:可用于密集任務(wù)backbone的金字塔視覺transformer!
漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA
不妨試試MoCo,來替換ImageNet上pretrain模型!
機(jī)器學(xué)習(xí)算法工程師
一個(gè)用心的公眾號

