TNT | 致敬Network in Network,華為諾亞提出Transformer-in-Transformer

極市導(dǎo)讀
本文是華為諾亞方舟實(shí)驗(yàn)在Transformer方面的又一次探索,針對(duì)現(xiàn)有Transformer存在打破圖像塊的結(jié)構(gòu)信息的問(wèn)題,提出了一種新穎的同時(shí)進(jìn)行patch與pixel表達(dá)建模的TNT模塊。 >>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺(jué)的最前沿

本文是華為諾亞方舟實(shí)驗(yàn)在Transformer方面的又一次探索,針對(duì)現(xiàn)有Transformer存在打破圖像塊的結(jié)構(gòu)信息的問(wèn)題,提出了一種新穎的同時(shí)進(jìn)行patch與pixel表達(dá)建模的TNT模塊,它包含用于塊嵌入建模的Outer Transformer 模塊與像素嵌入建模的Inner Transformer模塊,通過(guò)這種方式使得TNT可以同時(shí)提取全局與局部結(jié)構(gòu)信息。在ImageNet數(shù)據(jù)集上,TNT-S模型以81.3%的top1精度超過(guò)了DeiT-S的的79.8%;TNT-B以82.8%的top1精度超過(guò)了DeiT-B的81.8%的top1精度。
Abstract
Transformer是一種自注意力機(jī)制神經(jīng)網(wǎng)絡(luò),最早興起于NLP領(lǐng)域。近來(lái),純transformer模型已被提出并用于CV的各個(gè)領(lǐng)域,比如用于low-level問(wèn)題的IPT,detection的DETR,classification的ViT,segmentation的SETR等等。然而這些Visual Transformer通過(guò)將圖像視作塊序列而忽視了它們最本質(zhì)的結(jié)構(gòu)信息。
針對(duì)上述問(wèn)題,我們提出了一種新穎的Transformer iN Transformer(TNT)模型用于對(duì)patch與pixel層面特征建模。在每個(gè)TNT模塊中,outer transformer block用于處理塊嵌入,而inner transformer block用于處理像素嵌入的局部特征,像素級(jí)特征通過(guò)線性變換投影到塊嵌入空間并與塊嵌入相加。通過(guò)堆疊TNT模塊,我們構(gòu)建了TNT模塊用于圖像識(shí)別。
我們?cè)贗mageNet與下游任務(wù)上驗(yàn)證了所提TNT架構(gòu)的優(yōu)越性,比如,在相似計(jì)算復(fù)雜度下,TNT在ImageNet上取得了81.3%的top1精度,以1.5%優(yōu)于DeiT。
Method
接下來(lái),我們將重點(diǎn)描述本文所提TNT架構(gòu)并對(duì)其復(fù)雜度進(jìn)行分析。在正式介紹之前,我們先對(duì)transformer的一些基本概念進(jìn)行簡(jiǎn)單介紹。
Preliminaries
Transformer的基本概念包含MSA(Multi-head Self-Attention)、MLP(Multi-Layer Perceptron)以及LN(Layer Normalization)等。
MSA 在自注意力模塊中,輸入將被線性變換為三部分,即queries, keys, values。其中n表示序列長(zhǎng)度,分別表示輸入、queries以及values的維度。此時(shí)自注意力機(jī)制可以描述如下:
最后,通過(guò)一個(gè)線性層生成最終的輸出。而多頭自注意力會(huì)將queries、keys、values拆分h次分別實(shí)施上述注意力機(jī)制,最后將每個(gè)頭的輸出concat并線性投影得到最后的輸出。
MLP MLP是位于自注意力之間的一個(gè)特征變換模塊,起定義如下:
其中表示激活函數(shù),常用GELU,其他參數(shù)則是全連接層的weight與bias,不再贅述。
LN LN 是確保transformer穩(wěn)定訓(xùn)練與快速手鏈的關(guān)鍵部分,起定義如下:
其中,分別表示特征的均值與標(biāo)準(zhǔn)差,o表示點(diǎn)乘操作,為可學(xué)習(xí)變換參數(shù)。
Transformer in Transformer
給定2D圖像,我們將其均勻的拆分為n塊,其中p表示每個(gè)圖像塊的大小。ViT一文采用了標(biāo)準(zhǔn)transformer處理塊序列,打破了塊間的局部結(jié)構(gòu)關(guān)系,可參考下圖a。

相反,本文提出了Transformer-iN-Transformer結(jié)構(gòu)同時(shí)學(xué)習(xí)圖像的全局與局部信息。在每個(gè)TNT模塊中,每個(gè)塊通過(guò)unfold進(jìn)一步變換到目標(biāo)尺寸,結(jié)合線性投影,塊序列變?yōu)椋?/p>
其中, c表示通道數(shù)量。具體來(lái)說(shuō),我們將每個(gè)塊視作像素嵌入信息:
其中,。
在TNT內(nèi)部,我們具有兩個(gè)數(shù)據(jù)流,一個(gè)用于跨塊操作,一個(gè)用于塊內(nèi)像素操作。對(duì)于像素嵌入,我們采用transformer模塊探索像素之間的相關(guān)性:
其中表示層索引,L表示總共層數(shù)。所有塊張量變換為。它可以視作inner transformer block,表示為,該過(guò)程構(gòu)建了任意兩個(gè)像素之間的相關(guān)性。
在塊層面,我們創(chuàng)建了塊嵌入內(nèi)存以保存塊特征,其中表示類信息,切初始化為0。在每一層,塊張量通過(guò)線性投影變換到塊嵌入空間并與塊嵌入相加:
其中表示flatten操作。然后我們采用標(biāo)準(zhǔn)transformer模塊對(duì)塊嵌入進(jìn)行變換:
該輸出即為outer transformer block',它用于建模塊嵌入之間的相關(guān)性。
總而言之,TNT的輸入與輸出包含像素嵌入與塊嵌入,因此TNT可以表示為:
通過(guò)堆疊L次TNT模塊,我們即可構(gòu)建一個(gè)Transformer-in-Transformer網(wǎng)絡(luò),最后類別token作為圖像特征表達(dá),全連接層用于分類。

除了內(nèi)容/特征信息外,空間信息也是圖像識(shí)別非常重要的因素。對(duì)于塊嵌入與像素嵌入來(lái)說(shuō),我們同時(shí)添加了位置編碼信息,見(jiàn)上圖。這里采用標(biāo)準(zhǔn)1D可學(xué)習(xí)位置編碼信息,具體來(lái)說(shuō),每個(gè)塊被賦予一個(gè)位置編碼:
通過(guò)這種方式,塊位置編碼可以更好的保持全局空間結(jié)構(gòu)信息,而像素位置編碼可以保持局部相對(duì)位置關(guān)系。
Complexity Analysis
對(duì)于標(biāo)準(zhǔn)transformer而言,它包含兩部分:MSA與MLP。MSA的FLOPs如下:
而MLP的FLOPs則為。所以,標(biāo)準(zhǔn)transformer的整體FLOPs如下:
一般來(lái)說(shuō),所以FLOPs可以簡(jiǎn)化為,而參數(shù)量則是
本文所提TNT則包含三部分:inner transformer block, outer transformer block與線性層。的計(jì)算復(fù)雜度分別為,線性層的FLOPS則是。因此TNT的總體FLOPs則表示如下:
類似的TNT的參數(shù)量表示如下:
盡管TNT添加了兩個(gè)額外的成分,但FLOPs提升很小。TNT的Flops大約是標(biāo)準(zhǔn)模塊的1.09x,參數(shù)量大概是1.08x。通過(guò)小幅的參數(shù)量與計(jì)算量提升,所提TNT模塊可以有效的建模局部結(jié)構(gòu)信息并取得精度-復(fù)雜度的均衡。
Network Architecture
在最終網(wǎng)絡(luò)結(jié)構(gòu)配置方面,我們延續(xù)了ViT與DeiT的配置方式。塊大小為,unfold塊大小。下表給出了TNT網(wǎng)絡(luò)的不同大小的配置信息,它們分別包含23.8M核65.6M參數(shù)量,對(duì)應(yīng)的FLOPs分別為5.2B與14.1B(注:輸入圖像尺寸為)。

Operational Optimizations 此外,啟發(fā)與SE,我們進(jìn)行tansformer的通道注意力機(jī)制探索。我們首先對(duì)所有patch/pixel嵌入進(jìn)行平均,然后采用兩層MLP計(jì)算注意力,所的注意力與所有嵌入相乘。SE模塊僅僅帶來(lái)非常少的參數(shù)量,但有助于進(jìn)行通道層面的特征增強(qiáng)。
Experiments
為驗(yàn)證所提方案的有效性,我們?cè)贗mageNet以及其他下游數(shù)據(jù)上進(jìn)行了對(duì)比分析,相關(guān)數(shù)據(jù)信息如下所示。

訓(xùn)練超參方面的配置信息如下所示。

我們先來(lái)看一下TNT、CNN以及其他Transformer在ImageNet上的性能對(duì)比,結(jié)果見(jiàn)下表。

從上表可以看到:
所提TNT模型優(yōu)于其他所有Transformer模塊,TNT-S取得了81.3%top-1精度并以1.5%指標(biāo)優(yōu)于DeiT-S;通過(guò)添加SE模塊,其性能可以進(jìn)一步提升到81.6%top-1。 相比CNN模型,TNT優(yōu)于廣泛采用的ResNet與RegNet。
最后,我們?cè)賮?lái)看一下在下游任務(wù)的遷移效果,結(jié)果見(jiàn)效果。注:所有模型在分辨率進(jìn)行了微調(diào)。

從上表可以看到:
在遷移學(xué)習(xí)方面,TNT取得了比DeiT更優(yōu)的效果; 通過(guò)更高分辨率的微調(diào),TNT-B取得了83.9%的top-1精度。
全文到此結(jié)束,更多消融實(shí)驗(yàn)與分析建議各位同學(xué)查看原文。
推薦閱讀
2021-03-02

2021-01-28

2021-01-24


# CV技術(shù)社群邀請(qǐng)函 #
備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)
即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群
每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與 10000+來(lái)自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺(jué)開(kāi)發(fā)者互動(dòng)交流~

