網(wǎng)絡(luò)架構(gòu)設(shè)計:CNN based和Transformer based
點擊上方“AI算法與圖像處理”,選擇加"星標"或“置頂”
重磅干貨,第一時間送達
來源:Smarter
從DETR到ViT等工作都驗證了Transformer在計算機視覺領(lǐng)域的潛力,那么很自然的就需要考慮一個新的問題,圖像的特征提取,究竟是CNN好還是Transformer好?
其中CNN的優(yōu)勢在于參數(shù)共享,關(guān)注local信息的聚合,而Transformer的優(yōu)勢在于全局感受野,關(guān)注global信息的聚合。直覺上來講global和local的信息聚合都是有用的,將global信息聚合和local信息聚合有效的結(jié)合在一起可能是設(shè)計最佳網(wǎng)絡(luò)架構(gòu)的正確方向。
如何有效的結(jié)合global和local信息,最近的幾篇文章主要分成了兩個方向:CNN based和Transformer based。以下主要解析一下CNN based和Transformer based的網(wǎng)絡(luò)架構(gòu)設(shè)計,其中CNN based涉及ResNet和BoTNet,Transformer based涉及ViT和T2T-ViT。
01
網(wǎng)絡(luò)架構(gòu)設(shè)計的相互關(guān)系

BoTNet在ResNet的基礎(chǔ)上將Bottlenneck的3x3卷積替換成MHSA,增加CNN based的網(wǎng)絡(luò)架構(gòu)的global信息聚合能力。T2T-ViT在ViT的基礎(chǔ)上將patch的linear projection替換成T2T,增加Transformer based的網(wǎng)絡(luò)架構(gòu)的local信息聚合能力。
02
ResNet&BoTNet

ResNet的結(jié)構(gòu)設(shè)計,ResNet主要由Bottleneck結(jié)構(gòu)堆疊而成,一層Bottlenneck由1x1conv、3x3conv和1x1conv堆疊構(gòu)成殘差分支,然后和skip connect分支相加。BoTNet在Bottlenneck結(jié)構(gòu)的基礎(chǔ)上將中間的3x3conv替換成MHSA結(jié)構(gòu),跟之間的Non-local等工作非常相似,本質(zhì)上在CNN中引入global信息聚合。

MHSA結(jié)構(gòu)如上圖所示,代碼如下。
class MHSA(nn.Module):
? def __init__(self, n_dims, width=14, height=14):
? ? ? super(MHSA, self).__init__()
? ? ? self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
? ? ? self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
? ? ? self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
? ? ? self.rel_h = nn.Parameter(torch.randn([1, n_dims, 1, height]), requires_grad=True)
? ? ? self.rel_w = nn.Parameter(torch.randn([1, n_dims, width, 1]), requires_grad=True)
? ? ? self.softmax = nn.Softmax(dim=-1)
? def forward(self, x):
? ? ? n_batch, C, width, height = x.size()
? ? ? q = self.query(x).view(n_batch, C, -1)
? ? ? k = self.key(x).view(n_batch, C, -1)
? ? ? v = self.value(x).view(n_batch, C, -1)
? ? ? content_content = torch.bmm(q.permute(0, 2, 1), k)
? ? ? content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)
? ? ? content_position = torch.matmul(content_position, q)
? ? ? energy = content_content + content_position
? ? ? attention = self.softmax(energy)
? ? ? out = torch.bmm(v, attention.permute(0, 2, 1))
? ? ? out = out.view(n_batch, C, width, height)
? ? ? return out
跟Transformer中的multi-head self-attention非常相似,區(qū)別在于MSHA將position encoding當成了spatial attention來處理,嵌入兩個可學(xué)習(xí)的向量看成是橫縱兩個維度的空間注意力,然后將相加融合后的空間向量于q相乘得到contect-position(相當于是引入了空間先驗),將content-position和content-content相乘得到空間敏感的相似性feature,讓MHSA關(guān)注合適區(qū)域,更容易收斂。另外一個不同之處是MHSA只在藍色塊部分引入multi-head。
03
ViT
ViT是第一篇純粹的將Transformer用于圖像特征抽取的文章。

Vision Transformer(ViT)將輸入圖片拆分成16x16個patches,每個patch做一次線性變換降維同時嵌入位置信息,然后送入Transformer。類似BERT[class]標記位的設(shè)置,ViT在Transformer輸入序列前增加了一個額外可學(xué)習(xí)的[class]標記位,并且該位置的Transformer Encoder輸出作為圖像特征。
假設(shè)輸入圖片大小是256x256,打算分成64個patch,每個patch是32x32像素。
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
# 將3072變成dim,假設(shè)是1024
self.patch_to_embedding = nn.Linear(patch_dim, dim)
x = self.patch_to_embedding(x)
這個寫法是采用了愛因斯坦表達式,具體是采用了einops庫實現(xiàn),內(nèi)部集成了各種算子,rearrange就是其中一個,非常高效。p就是patch大小,假設(shè)輸入是b,3,256,256,則rearrange操作是先變成(b,3,8x32,8x32),最后變成(b,8x8,32x32x3)即(b,64,3072),將每張圖片切分成64個小塊,每個小塊長度是32x32x3=3072,也就是說輸入長度為64的圖像序列,每個元素采用3072長度進行編碼??紤]到3072有點大,ViT使用linear projection對圖像序列編碼進行降維。
04
T2T-ViT

ViT雖然驗證了Transformer在圖像分類網(wǎng)絡(luò)架構(gòu)設(shè)計的潛力,但是需要額外的大規(guī)模數(shù)據(jù)來進行pre-train,而在中等規(guī)模數(shù)據(jù)集如imagenet上效果卻不理想。T2T-ViT引入了local的信息聚合來增強ViT局部結(jié)構(gòu)建模的能力,使得T2T-ViT在中等規(guī)模imagenet上訓(xùn)練能達到更高的精度。
在T2T模塊中,先將輸入圖像軟分割為小塊,然后將其展開成一個tokens T0序列。然后tokens的長度在T2T模塊中逐步減少(文章中使用兩次迭代然后輸出Tf)。后續(xù)跟ViT基本上一致。

一次迭代T2T結(jié)構(gòu)由re-structurization和soft split構(gòu)成,re-structurization將一維序列reshape成二維圖像, soft split對二維圖像進行滑窗操作,拆分成重疊塊。
以token transformer為例,先將輸入圖像拆分成7x7的重疊塊,然后通過token transformer,進行塊內(nèi)的global信息聚合,然后通過re-structurization和soft split進行token重組和拆分成3x3的重疊塊,得到長度更短的token序列,重復(fù)迭代兩次,最后linear projection進一步降低token序列長度。
class T2T_module(nn.Module):
? """
? Tokens-to-Token encoding module
? """
? def __init__(self, img_size=224, in_chans=3, embed_dim=768, token_dim=64):
? ? ? super().__init__()
? ? ? self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
? ? ? self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
? ? ? self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
? ? ? self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)
? ? ? self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)
? ? ? self.project = nn.Linear(token_dim * 3 * 3, embed_dim)
? ? ? self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2)) # there are 3 soft split, stride are 4,2,2 seperately
? def forward(self, x):
? ? ? # step0: soft split
? ? ? x = self.soft_split0(x).transpose(1, 2)
? ? ? # iteration1: restricturization/reconstruction
? ? ? x = self.attention1(x)
? ? ? B, new_HW, C = x.shape
? ? ? x = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
? ? ? # iteration1: soft split
? ? ? x = self.soft_split1(x).transpose(1, 2)
? ? ? # iteration2: restricturization/reconstruction
? ? ? x = self.attention2(x)
? ? ? B, new_HW, C = x.shape
? ? ? x = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
? ? ? # iteration2: soft split
? ? ? x = self.soft_split2(x).transpose(1, 2)
? ? ? # final tokens
? ? ? x = self.project(x)
? ? ? return x
05
總結(jié)
1.global和local信息聚合的關(guān)系
global和local應(yīng)該相互補充來同時balance 速度和精度,同時提升速度和精度的上限
2.CNN based和Transformer based的關(guān)系,CNN based 和 Transformer based哪個好
本質(zhì)上是網(wǎng)絡(luò)架構(gòu)設(shè)計是以CNN為主好還是Transformer為主好的問題,CNN為主還是將輸入當成二維的圖像信號來處理,Transformer為主則將輸入當成一維的序列信號來處理,所以想要研究清楚CNN為主好還是Transformer為主好的問題,需要去探索哪種輸入信號更加具有優(yōu)勢,之前不少研究都表明CNN的padding可能透露了位置信息,而Transformer因為沒有歸納偏見,需要增加position encoding來引入位置信息。CNN為主和Transformer為主各有優(yōu)劣,目前來看暫無定論,且看后續(xù)發(fā)展。
Reference
[1] Deep Residual Learning for Image Recognition
[2]?Bottleneck Transformers for Visual Recognition
[3] An image is worth 16x16 words: Transformers for image recognition at scale
[4] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet
個人微信(如果沒有備注不拉群!) 請注明:地區(qū)+學(xué)校/企業(yè)+研究方向+昵稱
下載1:何愷明頂會分享
在「AI算法與圖像處理」公眾號后臺回復(fù):何愷明,即可下載。總共有6份PDF,涉及 ResNet、Mask RCNN等經(jīng)典工作的總結(jié)分析
下載2:終身受益的編程指南:Google編程風(fēng)格指南
在「AI算法與圖像處理」公眾號后臺回復(fù):c++,即可下載。歷經(jīng)十年考驗,最權(quán)威的編程規(guī)范!
下載3 CVPR2020 在「AI算法與圖像處理」公眾號后臺回復(fù):CVPR2020,即可下載1467篇CVPR?2020論文
覺得不錯就點亮在看吧

