淺談Transformer+CNN混合架構(gòu):CMT以及從0-1復(fù)現(xiàn)

極市導(dǎo)讀
本文詳細(xì)講解了華為諾亞與悉尼大學(xué)在Transformer+CNN架構(gòu)混合方面的嘗試,一種同時具有Transformer長距離建模與CNN局部特征提取能力的CMT。并給出了自己從0-1的復(fù)現(xiàn)過程以及是實(shí)驗結(jié)果。 >>加入極市CV技術(shù)交流群,走在計算機(jī)視覺的最前沿
論文鏈接: https://arxiv.org/abs/2107.06263
論文代碼(個人實(shí)現(xiàn)版本): https://github.com/FlyEgle/CMT-pytorch
知乎專欄:https://www.zhihu.com/people/flyegle
寫在前面
本篇博客講解CMT模型并給出從0-1復(fù)現(xiàn)的過程以及實(shí)驗結(jié)果,由于論文的細(xì)節(jié)并沒有給出來,所以最后的復(fù)現(xiàn)和paper的精度有一點(diǎn)差異,等作者release代碼后,我會詳細(xì)的校對我自己的code,找找原因。
1. 出發(fā)點(diǎn)
Transformers與現(xiàn)有的卷積神經(jīng)網(wǎng)絡(luò)(CNN)在性能和計算成本方面仍有差距。 希望提出的模型不僅可以超越典型的Transformers,而且可以超越高性能卷積模型。
2. 怎么做
提出混合模型(串行),通過利用Transformers來捕捉長距離的依賴關(guān)系,并利用CNN來獲取局部特征。 引入depth-wise卷積,獲取局部特征的同時,減少計算量 使用類似R50模型結(jié)構(gòu)一樣的stageblock,使得模型具有下采樣增強(qiáng)感受野和遷移dense的能力。 使用conv-stem來使得圖像的分辨率縮放從VIT的1/16變?yōu)?/4,保留更多的patch信息。
3. 模型結(jié)構(gòu)

(a)表示的是標(biāo)準(zhǔn)的R50模型,具有4個stage,每個都會進(jìn)行一次下采樣。最后得到特征表達(dá)后,經(jīng)過AvgPool進(jìn)行分類 (b)表示的是標(biāo)準(zhǔn)的VIT模型,先進(jìn)行patch的劃分,然后embeeding后進(jìn)入Transformer的block,這里,由于Transformer是long range的,所以進(jìn)入什么,輸出就是什么,引入了一個非image的class token來做分類。 (c)表示的是本文所提出的模型框架CMT,由CMT-stem, downsampling, cmt block所組成,整體結(jié)構(gòu)則是類似于R50,所以可以很好的遷移到dense任務(wù)上去。
3.1. CMT Stem
使用convolution來作為transformer結(jié)構(gòu)的stem,這個觀點(diǎn)FB也有提出一篇paper,Early Convolutions Help Transformers See Better。
https://arxiv.org/abs/2106.14881
CMT&Conv stem共性
使用4層conv3x3+stride2 + conv1x1 stride 1 等價于VIT的patch embeeding,conv16x16 stride 16. 使用conv stem,可以使模型得到更好的收斂,同時,可以使用SGD優(yōu)化器來訓(xùn)練模型,對于超參數(shù)的依賴沒有原始的那么敏感。好處那是大大的多啊,僅僅是改了一個conv stem。
CMT&Conv stem異性
本文僅僅做了一次conv3x3 stride2,實(shí)際上只有一次下采樣,相比conv stem,可以保留更多的patch的信息到下層。
從時間上來說,一個20210628(conv stem), 一個是20210713(CMT stem),存在借鑒的可能性還是比較小的,也說明了conv stem的確是work。
3.2. CMT Block
每一個stage都是由CMT block所堆疊而成的,CMT block由于是transformer結(jié)構(gòu),所以沒有在stage里面去設(shè)計下采樣。每個CMT block都是由Local Perception Unit, Ligntweight MHSA, Inverted Residual FFN這三個模塊所組成的,下面分別介紹:
Local Perception Unit(LPU)

本文的一個核心點(diǎn)是希望模型具有l(wèi)ong-range的能力,同時還要具有l(wèi)ocal特征的能力,所以提出了LPU這個模塊,很簡單,一個3X3的DWconv,來做局部特征,同時減少點(diǎn)計算量,為了讓Transformers的模塊獲取的longrange的信息不缺失,這里做了一個shortcut,公式描述為:
Lightweight MHSA(LMHSA)

MHSA這個不用多說了,多頭注意力,Lightweight這個作用,PVT 曾經(jīng)有提出過,目的是為了降低復(fù)雜度,減少計算量。那本文是怎么做的呢,很簡單,假設(shè)我們的輸入為, 對其分別做一個scale,使用卷積核為,stride為的Depth Wise卷積來做了一次下采樣,得到的shape為,那么對應(yīng)的Q,K,V的shape分別為:
我們知道,在計算MHSA的時候要遵守兩個計算原則:
Q, K的序列dim要一致。 K, V的token數(shù)量要一致。
所以,本文中的MHSA計算公式如下:
Inverted Resdiual FFN(IRFFN)

FFN的這個模塊,其實(shí)和mbv2的block基本上就是一樣的了,不一樣的地方在于,使用的是GELU,采用的也是DW+PW來減少標(biāo)準(zhǔn)卷積的計算量。很簡單,就不多說了,公式如下:
那么我們一個block里面的整體計算公式如下:
3.3 patch aggregation
每個stage都是由上述的多個CMTblock所堆疊而成, 上面也提到了,這里由于是transformer的操作,不會設(shè)計到scale尺度的問題,但是模型需要構(gòu)造下采樣,來實(shí)現(xiàn)層次結(jié)構(gòu),所以downsampling的操作單獨(dú)拎了出來,每個stage之前會做一次卷積核為2x2的,stride為2的卷積操作,以達(dá)到下采樣的效果。
所以,整體的模型結(jié)構(gòu)就一目了然了,假設(shè)輸入為224x224x3,經(jīng)過CMT-STEM和第一次下采樣后,得到了一個56x56的featuremap,然后進(jìn)入stage1,輸出不變,經(jīng)過下采樣后,輸入為28x28,進(jìn)入stage2,輸出后經(jīng)過下采樣,輸入為14x14,進(jìn)入stage3,輸出后經(jīng)過最后的下采樣,輸入為7x7,進(jìn)入stage4,最后輸出7x7的特征圖,后面接avgpool和分類,達(dá)到分類的效果。
我們接下來看一下怎么復(fù)現(xiàn)這篇paper。
4. 論文復(fù)現(xiàn)
ps: 這里的復(fù)現(xiàn)指的是沒有源碼的情況下,實(shí)現(xiàn)網(wǎng)絡(luò),訓(xùn)練等,如果是結(jié)果復(fù)現(xiàn),會標(biāo)明為復(fù)現(xiàn)精度。
這里存在幾個問題
文章的問題:我看到paper的時候,是第一個版本的arxiv,大概過了一周左右V2版本放出來了,這兩個版本有個很大的diff。


網(wǎng)絡(luò)結(jié)構(gòu)可以說完全不同的情況下,F(xiàn)LOPs竟然一樣的,當(dāng)然可能是寫錯了,這里就不吐槽了。不過我一開始代碼復(fù)現(xiàn)就是按下面來的,所以對于我也沒影響多少,只是體驗有點(diǎn)差罷了。
細(xì)節(jié)的問題:paper和很多的transformer一樣,都是采用了Deit的訓(xùn)練策略,但是差別在于別的paper或多或少會給出來額外的tirck,比如最后FC的dp的ratio等,或者會改變一些,再不濟(jì)會把代碼直接release了,所以只好悶頭嘗試Trick。
4.1 復(fù)現(xiàn)難點(diǎn)
paper里面采用的Position Embeeding和Swin是類似的,都是Relation Position Bias,但是和Swin不相同的是,我們的Q,K,V尺度是不一樣的。這里我考慮了兩種實(shí)現(xiàn)方法,一種是直接bicubic插值,另一種則是切片,切片更加直觀且embeeding我設(shè)置的可BP,所以,實(shí)現(xiàn)里面采用的是這種方法,代碼如下:
def generate_relative_distance(number_size):
"""return relative distance, (number_size**2, number_size**2, 2)
"""
indices = torch.tensor(np.array([[x, y] for x in range(number_size) for y in range(number_size)]))
distances = indices[None, :, :] - indices[:, None, :]
distances = distances + number_size - 1 # shift the zeros postion
return distances
...
elf.position_embeeding = nn.Parameter(torch.randn(2 * self.features_size - 1, 2 * self.features_size - 1))
...
q_n, k_n = q.shape[1], k.shape[2]
attn = attn + self.position_embeeding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]][:, :k_n]
4.2 復(fù)現(xiàn)trick歷程(血與淚TT)
一方面想要看一下model是否是work的,一方面想要順便驗證一下DeiT的策略是否真的有效,所以從頭開始做了很多的實(shí)驗,簡單整理如下:
數(shù)據(jù):
訓(xùn)練數(shù)據(jù): 20%的imagenet訓(xùn)練數(shù)據(jù)(快速實(shí)驗)。 驗證數(shù)據(jù): 全量的imagenet驗證數(shù)據(jù)。
環(huán)境:
8xV100 32G CUDA 10.2 + pytorch 1.7.1
sgd優(yōu)化器實(shí)驗記錄

結(jié)論: 可以看到在SGD優(yōu)化器的情況下,使用1.6的LR,訓(xùn)練300個epoch,warmup5個epoch,是用cosine衰減學(xué)習(xí)率的策略,用randaug+colorjitter+mixup+cutmix+labelsmooth,設(shè)置weightdecay為0.1的配置下,使用QKV的bias以及相對位置偏差,可以達(dá)到比baseline高11%個點(diǎn)的結(jié)果,所有的實(shí)驗都是用FP16跑的。
adamw優(yōu)化器實(shí)驗記錄

結(jié)論:使用AdamW的情況下,對學(xué)習(xí)率的縮放則是以512的bs為基礎(chǔ),所以對于4k的bs情況下,使用的是4e-3的LR,但是實(shí)驗發(fā)現(xiàn)增大到6e-3的時候,還會帶來一些提升,同時放大一點(diǎn)weightsdecay,也略微有所提升,最終使用AdamW的配置為,6e-3的LR,1e-1的weightdecay,和sgd一樣的增強(qiáng)方法,然后加上了隨機(jī)深度失活設(shè)置,最后比baseline高了16%個點(diǎn),比SGD最好的結(jié)果要高0.8%個點(diǎn)。
4.3. imagenet上的結(jié)果

最后用全量跑,使用SGD會報nan的問題,我定位了一下發(fā)現(xiàn),running_mean和running_std有nan出現(xiàn),本以為是數(shù)據(jù)增強(qiáng)導(dǎo)致的0或者nan值出現(xiàn),結(jié)果空跑幾次數(shù)據(jù)發(fā)現(xiàn)沒問題,只好把優(yōu)化器改成了AdamW,結(jié)果上述所示,CMT-Tiny在160x160的情況下達(dá)到了75.124%的精度,相比MbV2,MbV3的確是一個不錯的精度了,但是相比paper本身的精度還是差了將近4個點(diǎn),很是離譜。
速度上,CMT雖然FLOPs低,但是實(shí)際的推理速度并不快,128的bs條件下,速度慢了R50將近10倍。
5. 實(shí)驗結(jié)果
總體來說,CMT達(dá)到了更小的FLOPs同時有著不錯的精度, imagenet上的結(jié)果如下:

coco2017上也有這不錯的精度

6. 結(jié)論
本文提出了一種名為CMT的新型混合架構(gòu),用于視覺識別和其他下游視覺任務(wù),以解決在計算機(jī)視覺領(lǐng)域以粗暴的方式利用Transformers的限制。所提出的CMT同時利用CNN和Transformers的優(yōu)勢來捕捉局部和全局信息,促進(jìn)網(wǎng)絡(luò)的表示能力。在ImageNet和其他下游視覺任務(wù)上進(jìn)行的大量實(shí)驗證明了所提出的CMT架構(gòu)的有效性和優(yōu)越性。
代碼復(fù)現(xiàn)repo:
https://github.com/FlyEgle/CMT-pytorch
實(shí)現(xiàn)不易,求個star!
本文亮點(diǎn)總結(jié)
使用4層conv3x3+stride2 + conv1x1 stride 1 等價于VIT的patch embeeding,conv16x16 stride 16. 使用conv stem,可以使模型得到更好的收斂,同時,可以使用SGD優(yōu)化器來訓(xùn)練模型,對于超參數(shù)的依賴沒有原始的那么敏感。好處那是大大的多啊,僅僅是改了一個conv stem。
如果覺得有用,就請分享到朋友圈吧!
公眾號后臺回復(fù)“CVPR21檢測”獲取CVPR2021目標(biāo)檢測論文下載~

# CV技術(shù)社群邀請函 #

備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測-深圳)
即可申請加入極市目標(biāo)檢測/圖像分割/工業(yè)檢測/人臉/醫(yī)學(xué)影像/3D/SLAM/自動駕駛/超分辨率/姿態(tài)估計/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群
每月大咖直播分享、真實(shí)項目需求對接、求職內(nèi)推、算法競賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動交流~

