Tokens-to-token ViT: 對token做編碼的純transformer ViT,T2T算引入了...
【GaintPandaCV導(dǎo)語】?
T2T-ViT是純transformer的形式,先對原始數(shù)據(jù)做了token編碼后,再堆疊Deep-narrow網(wǎng)絡(luò)結(jié)構(gòu)的transformer模塊,實(shí)際上T2T也引入了CNN。
引言
一句話概括:也是純transformer的形式,先對原始數(shù)據(jù)做了token編碼后,再堆疊Deep-narrow網(wǎng)絡(luò)結(jié)構(gòu)的transformer模塊。對token編碼筆者認(rèn)為本質(zhì)上是做了局部特征提取也就是CNN擅長做的事情。
原論文作者認(rèn)為ViT效果不及CNN的原因:
1、直接將圖像分patch后生成token的方式?jīng)]法建模局部結(jié)構(gòu)特征(local structure),比如相鄰位置的線,邊緣;
2、在限定計(jì)算量和限定訓(xùn)練數(shù)據(jù)數(shù)量的條件下,ViT冗余的注意力骨架網(wǎng)絡(luò)設(shè)計(jì)導(dǎo)致提取不到豐富的特征。
所以針對這倆點(diǎn)就提出兩個解決方法:
1、找一種高效生成token的方法,即 Tokens-to-Token (T2T)
2、設(shè)計(jì)一個新的純transformer的網(wǎng)絡(luò),即deep-narrow,并對比了目前的流行的CNN網(wǎng)絡(luò)。
當(dāng)然對比完后是作者提出的Deep-narrow效果最好。原文的對比實(shí)驗(yàn)值得去借鑒(抄)。
1). 密稠連接,Dense Connection,類比ResNet和DenseNet
2).Deep-narrow 對比shallow-Wide,類比Wide-ResNet
3).通道注意力,類比SE-ResNet
4).在多頭注意力層加入更多頭,類比ResNeXt
5).Ghost操作,即減少conv的輸出通道后再通過DWConv和skip connect將這倆concat起來,類比GhostNet
實(shí)驗(yàn)的結(jié)果:給出來了煉丹配方了,這一點(diǎn)還是很良心的,根據(jù)現(xiàn)有的CNN的模型架構(gòu)特征改造純transformer
Deep-narrow能提高VIT的特征豐富性,模型大小和MACs降低,整體效果也提升了;通道注意力對ViT也有提升,但Deep-narrow結(jié)構(gòu)更加高效;密稠連接會影響性能;
筆者認(rèn)為最重要的token的生成,即可Tokens-to-token模塊。

直接看圖來分析分析,是怎么做T2T的,看上面Firgure 4橘黃色部分。
步驟1:有重疊地取圖像的區(qū)域,實(shí)際上這個區(qū)域就是做卷積的窗口,這個窗口大小是7×7,stride為4,padding為2,然后調(diào)用nn.Unfold函數(shù)將[7,7]攤平成[49](也就是把一張餅變成一長條),其實(shí)也就是img2col,這一步命名為"soft split";
步驟2:對攤平的長條做變換,這里使用了transformer,可以用performer來降低transformer的計(jì)算復(fù)雜度,這一步命名為"re-structurization/reconstruction";
步驟3:將步驟2出來的結(jié)果(B,H×W,C)reshape成一個4維度(B,C,H,W)矩陣;
步驟4:跟步驟1一樣,取一個窗口的數(shù)值,即nn.Unfold,這次窗口是3×3,stride為2,padding為1;
步驟5:跟步驟2一樣,對取到的長條做變換,即可transformer或者performer;
步驟6:跟步驟3一樣,reshape成一個4維度矩陣;
步驟7:跟步驟4一樣,參數(shù)也一樣,取出長條;
步驟8:將步驟7出來的長條做一次全連接生成固定的token數(shù)量。
整個Tokens-to-token就完成了。
代碼及分析
看看代碼:
class?T2T_module(nn.Module):
????"""
????Tokens-to-Token?encoding?module
????"""
????def?__init__(self,?img_size=224,?tokens_type='performer',?in_chans=3,?embed_dim=768,?token_dim=64):
????????super().__init__()
????????if?tokens_type?==?'transformer':
????????????print('adopt?transformer?encoder?for?tokens-to-token')
????????????self.soft_split0?=?nn.Unfold(kernel_size=(7,?7),?stride=(4,?4),?padding=(2,?2))
????????????self.soft_split1?=?nn.Unfold(kernel_size=(3,?3),?stride=(2,?2),?padding=(1,?1))
????????????self.soft_split2?=?nn.Unfold(kernel_size=(3,?3),?stride=(2,?2),?padding=(1,?1))
????????????self.attention1?=?Token_transformer(dim=in_chans?*?7?*?7,?in_dim=token_dim,?num_heads=1,?mlp_ratio=1.0)
????????????self.attention2?=?Token_transformer(dim=token_dim?*?3?*?3,?in_dim=token_dim,?num_heads=1,?mlp_ratio=1.0)
????????????self.project?=?nn.Linear(token_dim?*?3?*?3,?embed_dim)
????????elif?tokens_type?==?'performer':
????????????print('adopt?performer?encoder?for?tokens-to-token')
????????????self.soft_split0?=?nn.Unfold(kernel_size=(7,?7),?stride=(4,?4),?padding=(2,?2))
????????????self.soft_split1?=?nn.Unfold(kernel_size=(3,?3),?stride=(2,?2),?padding=(1,?1))
????????????self.soft_split2?=?nn.Unfold(kernel_size=(3,?3),?stride=(2,?2),?padding=(1,?1))
????????????#self.attention1?=?Token_performer(dim=token_dim,?in_dim=in_chans*7*7,?kernel_ratio=0.5)
????????????#self.attention2?=?Token_performer(dim=token_dim,?in_dim=token_dim*3*3,?kernel_ratio=0.5)
????????????self.attention1?=?Token_performer(dim=in_chans*7*7,?in_dim=token_dim,?kernel_ratio=0.5)
????????????self.attention2?=?Token_performer(dim=token_dim*3*3,?in_dim=token_dim,?kernel_ratio=0.5)
????????????self.project?=?nn.Linear(token_dim?*?3?*?3,?embed_dim)
????????elif?tokens_type?==?'convolution':??#?just?for?comparison?with?conolution,?not?our?model
????????????#?for?this?tokens?type,?you?need?change?forward?as?three?convolution?operation
????????????print('adopt?convolution?layers?for?tokens-to-token')
????????????self.soft_split0?=?nn.Conv2d(3,?token_dim,?kernel_size=(7,?7),?stride=(4,?4),?padding=(2,?2))??#?the?1st?convolution
????????????self.soft_split1?=?nn.Conv2d(token_dim,?token_dim,?kernel_size=(3,?3),?stride=(2,?2),?padding=(1,?1))?#?the?2nd?convolution
????????????self.project?=?nn.Conv2d(token_dim,?embed_dim,?kernel_size=(3,?3),?stride=(2,?2),?padding=(1,?1))?#?the?3rd?convolution
????????self.num_patches?=?(img_size?//?(4?*?2?*?2))?*?(img_size?//?(4?*?2?*?2))??#?there?are?3?sfot?split,?stride?are?4,2,2?seperately
????def?forward(self,?x):
????????#?step0:?soft?split
????????x?=?self.soft_split0(x).transpose(1,?2)
????????#?iteration1:?re-structurization/reconstruction
????????x?=?self.attention1(x)
????????B,?new_HW,?C?=?x.shape
????????x?=?x.transpose(1,2).reshape(B,?C,?int(np.sqrt(new_HW)),?int(np.sqrt(new_HW)))
????????#?iteration1:?soft?split
????????x?=?self.soft_split1(x).transpose(1,?2)
????????#?iteration2:?re-structurization/reconstruction
????????x?=?self.attention2(x)
????????B,?new_HW,?C?=?x.shape
????????x?=?x.transpose(1,?2).reshape(B,?C,?int(np.sqrt(new_HW)),?int(np.sqrt(new_HW)))
????????#?iteration2:?soft?split
????????x?=?self.soft_split2(x).transpose(1,?2)
????????#?final?tokens
????????x?=?self.project(x)
????????return?x
接下來看怎么對生成的token做transformer,看上面Firgure 4淺灰色部分,也就是堆疊transformer layer,最后加一個MLP做分類。transformer layer就是眾所周知的了。
然后就是怎么做堆疊呢?Deep-narrow的方式,也就是層數(shù)變多,維度變小,“高高瘦瘦”。這部分代碼也眾所周知了,就不貼代碼了。而且個人覺得,雖然作者對Deep-narrow的對比實(shí)驗(yàn)非常豐富,但我個人主觀認(rèn)為,網(wǎng)絡(luò)部分是為了結(jié)合T2T,你用其他網(wǎng)絡(luò)堆疊也是可以的,是一個調(diào)參過程。
這里我有個疑問,所以T2T這一部分跟CNN有什么區(qū)別呢?看看Figure 3。
在這里插入圖片描述我們知道CNN = unfold + matmul + fold。那么T2T模塊第一步做了unfold,然后對取出來的窗口做了transformer的非線性變化,這一步我們是不是可以理解為對窗口里面的像素點(diǎn)做了matmul呢?這里的matmul可能更像是做attention。然后reshape回去相當(dāng)于做了fold操作。筆者認(rèn)為,T2T模塊,本質(zhì)上就是做了局部特征提取,也就CNN擅長做的事情。
個人主觀評價
T2T是一篇好文,應(yīng)該是第一篇提出要對token進(jìn)行處理的ViT工作,本意是為了提取更加高效的token,這樣可以減少token的數(shù)量,那么堆疊transformer模塊也能降低參數(shù)量和計(jì)算量。
但本質(zhì)上還是隱式引入了卷積,即有unfold + matmul + fold = CNN。對比與后來者ViTAE,T2T的解決方法其實(shí)更加簡潔。
