Tansformer | 詳細(xì)解讀:如何在CNN模型中插入Transformer后速度不變精度劇增?


1簡(jiǎn)介
本文工作解決了Multi-Head Self-Attention(MHSA)中由于計(jì)算/空間復(fù)雜度高而導(dǎo)致的vision transformer效率低的缺陷。為此,作者提出了分層的MHSA(H-MHSA),其表示以分層的方式計(jì)算。
具體來說,H-MHSA首先通過把圖像patch作為tokens來學(xué)習(xí)小網(wǎng)格內(nèi)的特征關(guān)系。然后將小網(wǎng)格合并到大網(wǎng)格中,通過將上一步中的每個(gè)小網(wǎng)格作為token來學(xué)習(xí)大網(wǎng)格中的特征關(guān)系。這個(gè)過程多次迭代以逐漸減少token的數(shù)量。
H-MHSA模塊很容易插入到任何CNN架構(gòu)中,并且可以通過反向傳播進(jìn)行訓(xùn)練。作者稱這種新的Backbone為TransCNN,它本質(zhì)上繼承了transformer和CNN的優(yōu)點(diǎn)。實(shí)驗(yàn)證明,TransCNN在圖像識(shí)別中具有最先進(jìn)的準(zhǔn)確性。
2Vision Transformer回顧
大家應(yīng)該都很清楚Transformer嚴(yán)重依賴MHSA來建模長時(shí)間依賴關(guān)系。假設(shè)為輸入,其中N和C分別為Token的數(shù)量和每個(gè)Token的特征維數(shù)。這里定義了Query 、key 和 value ,其中, , 為線性變換的權(quán)重矩陣。在假設(shè)輸入和輸出具有相同維度的情況下,傳統(tǒng)的MHSA可以表示為:

其中表示近似歸一化,對(duì)矩陣行應(yīng)用Softmax函數(shù)。注意,為了簡(jiǎn)單起見在這里省略了多個(gè)Head的概念。在上式中的矩陣乘積首先計(jì)算每對(duì)Token之間的相似度。然后,在所有Token的組合之上派生出每個(gè)新Token。MHSA計(jì)算后,進(jìn)一步添加殘差連接以方便優(yōu)化,如:

其中,為特征映射的權(quán)重矩陣。最后,采用MLP層增強(qiáng)表示,表示形式為:

其中Y表示transformer block的輸出。
有前面的等式可以得到MHSA的計(jì)算復(fù)雜度:

很容易推斷出空間復(fù)雜度(內(nèi)存消耗)。對(duì)于高分辨率的輸入,可能變得非常大,這限制了Transformer在視覺任務(wù)中的適用性?;诖?,本文的目標(biāo)是在不降低性能的情況下降低這種復(fù)雜性,并保持全局關(guān)系建模的能力。
Transformer Block Pytorch實(shí)現(xiàn)如下:
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
# Muliti-Head Self-Attention Block
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# 輸出 Q K V
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
# q matmul k.T
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
# attn' matmul v ==> output
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
# Transformer Encoder Block
# Embedded Patches ==> Layer Norm ==> Muliti-Head Attention + ==> Layer Norm ==> MLP + ==>
# |_________________________________________| |__________________|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# 進(jìn)行稀疏化操作,可以得到更好的結(jié)果
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
3Hierarchical Multi-Head Self-Attention
在這里,作者介紹了如何使用H-MHSA降低MHSA的計(jì)算/空間復(fù)雜度。這里不是在整個(gè)輸入中計(jì)算注意力,而是以分層的方式計(jì)算,這樣每個(gè)步驟只處理有限數(shù)量的Token。

圖b為H-MHSA的范式。假設(shè)輸入特征映射的高度為,寬度為,有。然后將特征圖劃分為大小為的小網(wǎng)格,并將特征圖Reshape為:

當(dāng), 和時(shí),式(1)生成局部注意。為了簡(jiǎn)化網(wǎng)絡(luò)優(yōu)化,這里將 Reshape為X的shape:

并添加一個(gè)殘差連接:

由于是在每個(gè)小網(wǎng)格內(nèi)計(jì)算的,因此計(jì)算/空間復(fù)雜度顯著降低。
對(duì)于第i步(i>0),將第(i-1)步處的每個(gè)更小的網(wǎng)格視為一個(gè)Token,這可以簡(jiǎn)單地通過對(duì)注意力特征進(jìn)行降采樣來實(shí)現(xiàn):

其中和分別表示使用最大池化和平均池化(內(nèi)核大小和步長為)將樣本降為次。因此,有, 其中,。然后,將劃分為網(wǎng)格,并將其Reshape為:

當(dāng), , 時(shí),方程(1)獲取注意特征。最終被Reshape為為輸入的shape,比如:

并添加一個(gè)殘差連接:

這個(gè)過程不斷迭代,直到足夠小而不能在進(jìn)行split。H-MHSA的最終輸出為:

如果Upsample(·)表示將注意力特征上采樣到原始大小,則與Equ(2)含義相同, M為最大步數(shù)。通過這種方式,H-MHSA可以等價(jià)于傳統(tǒng)的MHSA來模擬全局關(guān)系。
很容易證明,在所有都相同的假設(shè)下,H-MHSA的計(jì)算復(fù)雜度近似:

與MHSA的計(jì)算復(fù)雜度相比較,本文所提方法顯著降低了計(jì)算復(fù)雜度。
4將Transformer插入到CNN中
本文和之前將CNN與Transformer的方法一樣遵循普遍做法,在網(wǎng)絡(luò)Backbone中保留3D特征圖,并使用全局平均池化層和全連接層來預(yù)測(cè)圖像類別。這與現(xiàn)有的依賴另一個(gè)1D類標(biāo)記進(jìn)行預(yù)測(cè)的Transformer不同。
作者還觀察到以往的Transformer網(wǎng)絡(luò)通常采用GELU函數(shù)進(jìn)行非線性激活。然而,在網(wǎng)絡(luò)訓(xùn)練中,GELU函數(shù)非常耗費(fèi)內(nèi)存。作者通過經(jīng)驗(yàn)發(fā)現(xiàn),SiLU的功能與GELUs不相上下,而且更節(jié)省內(nèi)存。因此,TransCNN選擇使用SiLU函數(shù)進(jìn)行非線性激活。
作者做了一組實(shí)驗(yàn)。在ImageNet驗(yàn)證集上,當(dāng)訓(xùn)練為100個(gè)epoch時(shí),提出的具有SiLU的跨網(wǎng)絡(luò)網(wǎng)絡(luò)(TransCNN)在ImageNet驗(yàn)證集上獲得80.1%的top-1精度。GELU的TransCNN得到79.7%的top-1精度,略低于SiLU。當(dāng)每個(gè)GPU的batchsize=128時(shí),SiLU在訓(xùn)練階段占用20.2GB的GPU內(nèi)存,而GELU占用23.8GB的GPU內(nèi)存。

TransCNN的總體架構(gòu)如圖所示。
在TransCNN的開始階段使用了2個(gè)連續(xù)的個(gè)卷積,每個(gè)卷積的步長為2,將輸入圖像降采樣到1/4的尺度。

然后,將H-MHSA和卷積塊交替疊加,將其分為4個(gè)階段,分別以1/4,1/8,1/16,1/32的金字塔特征尺度進(jìn)行劃分。這里采用的卷積模塊是廣泛使用的Inverted Residual Bottleneck(IRB,圖c),卷積是深度可分離卷積。

在每個(gè)階段的末尾,作者設(shè)計(jì)了一個(gè)簡(jiǎn)單的二分支降采樣塊(TDB,圖d)。它由2個(gè)分支組成:一個(gè)分支是一個(gè)典型的卷積,步長為2;另一個(gè)分支是池化層和卷積。在特征降采樣中,這2個(gè)分支通過元素求和的方式融合,以保留更多的上下文信息。實(shí)驗(yàn)表明,TDB的性能優(yōu)于直接降采樣。

TransCNN的詳細(xì)配置如表所示。提供了2個(gè)版本的TransCNN: TransCNN-Small和TransCNN-Base。TransCNN-Base的參數(shù)個(gè)數(shù)與ResNet50相似。需要注意的是,這里只采用了最簡(jiǎn)單的參數(shù)設(shè)置,沒有進(jìn)行仔細(xì)的調(diào)優(yōu),以證明所提概念H-MHSA和trannn的有效性和通用性。例如,作者使用典型的通道數(shù),即64、128、256和512。MHSA中每個(gè)Head的尺寸被設(shè)置為64。作者提到對(duì)這些參數(shù)設(shè)置進(jìn)行細(xì)致的工程調(diào)整可以進(jìn)一步提高性能。
5實(shí)驗(yàn)
5.1 ImageNet圖像分類

通過上表可以看出,將H-MHSA插入到相應(yīng)的卷積模型中,可以以很少的參數(shù)量和FLOPs換取很大的精度提升。
5.2 MS-COCO 2017目標(biāo)檢測(cè)

通過上表可以看出,在比ResNet50更少的參數(shù)量的同時(shí),RetinaNet的AP得到了很大的提升。
5.3 MS-COCO 2017語義分割
通過上表可以看出,在比ResNet50更少的參數(shù)量的同時(shí),Mask R-CNN的AP得到了很大的提升??梢姳疚乃岱椒ǖ膶?shí)用性還是很強(qiáng)的。
6參考
[1].Transformer in Convolutional Neural Networks
7推薦閱讀

最強(qiáng)Transformer | 太頂流!Scaling ViT將ImageNet Top-1 Acc刷到90.45%啦?。?!

Transformer | 沒有Attention的Transformer依然是頂流?。?!

YOLO |多域自適應(yīng)MSDA-YOLO解讀,惡劣天氣也看得見(附論文)
本文論文原文獲取方式,掃描下方二維碼
回復(fù)【ViT-in-CNN】即可獲取論文
長按掃描下方二維碼添加小助手。
可以一起討論遇到的問題
聲明:轉(zhuǎn)載請(qǐng)說明出處
掃描下方二維碼關(guān)注【集智書童】公眾號(hào),獲取更多實(shí)踐項(xiàng)目源碼和論文解讀,非常期待你我的相遇,讓我們以夢(mèng)為馬,砥礪前行!

