解析 Vision Transformer
【GiantPandaCV導(dǎo)語】Vision Transformer將CV和NLP領(lǐng)域知識結(jié)合起來,對原始圖片進(jìn)行分塊,展平成序列,輸入進(jìn)原始Transformer模型的編碼器Encoder部分,最后接入一個全連接層對圖片進(jìn)行分類。在大型數(shù)據(jù)集上表現(xiàn)超過了當(dāng)前SOTA模型
前言
當(dāng)前Transformer模型被大量應(yīng)用在NLP自然語言處理當(dāng)中,而在計算機(jī)視覺領(lǐng)域,Transformer的注意力機(jī)制attention也被廣泛應(yīng)用,比如Se模塊,CBAM模塊等等注意力模塊,這些注意力模塊能夠幫助提升網(wǎng)絡(luò)性能。而我們的工作展示了不需要依賴CNN的結(jié)構(gòu),也可以在圖像分類任務(wù)上達(dá)到很好的效果,并且也十分適合用于遷移學(xué)習(xí)。
這里的代碼引用自 https://github.com/lucidrains/vit-pytorch,大家有興趣也可以跑跑demo。
方法
首先結(jié)構(gòu)上,我們采取的是原始Transformer模型,方便開箱即用。
如果對Transformer模型不太了解的可以參考這篇文章 解析Transformer模型
整體結(jié)構(gòu)如下

數(shù)據(jù)處理部分
原始輸入的圖片數(shù)據(jù)是 H x W x C,我們先對圖片作分塊,再進(jìn)行展平。假設(shè)每個塊的長寬為(P, P),那么分塊的數(shù)目為
然后對每個圖片塊展平成一維向量,每個向量大小為
總的輸入變換為
這里的代碼如下:
x?=?rearrange(img,?'b?c?(h?p1)?(w?p2)?->?b?(h?w)?(p1?p2?c)',?p1=p,?p2=p)
它使用的是一個einops的拓展包,完成了上述的變換工作
Patch Embedding
接著對每個向量都做一個線性變換(即全連接層),壓縮維度為D,這里我們稱其為 Patch Embedding。
在代碼里是初始化一個全連接層,輸出維度為dim,然后將分塊后的數(shù)據(jù)輸入
self.patch_to_embedding?=?nn.Linear(patch_dim,?dim)
#?forward前向代碼
x?=?rearrange(img,?'b?c?(h?p1)?(w?p2)?->?b?(h?w)?(p1?p2?c)',?p1=p,?p2=p)
x?=?self.patch_to_embedding(x)
Positional Encoding
還記得在解析Transformer那篇文章內(nèi)有說過,原始的Transformer引入了一個 Positional encoding 來加入序列的位置信息,同樣在這里也引入了pos_embedding,是用一個可訓(xùn)練的變量替代。
self.pos_embedding?=?nn.Parameter(torch.randn(1,?num_patches?+?1,?dim))
文章也提供了可視化圖

很有意思的是這里第二個維度多加了個1。下面會有講到
class_token
這里我們再來仔細(xì)看上圖的一個結(jié)構(gòu)

假設(shè)我們按照論文切成了9塊,但是在輸入的時候變成了10個向量。這是人為增加的一個向量。
因為傳統(tǒng)的Transformer采取的是類似seq2seq編解碼的結(jié)構(gòu) 而ViT只用到了Encoder編碼器結(jié)構(gòu),缺少了解碼的過程,假設(shè)你9個向量經(jīng)過編碼器之后,你該選擇哪一個向量進(jìn)入到最后的分類頭呢?因此這里作者給了額外的一個用于分類的向量,與輸入進(jìn)行拼接。同樣這是一個可學(xué)習(xí)的變量。
具體操作如下
#?假設(shè)dim=128,這里shape為(1,?1,?128)
self.cls_token?=?nn.Parameter(torch.randn(1,?1,?dim))
#?forward前向代碼
#?假設(shè)batchsize=10,這里shape為(10,?1,?128)
cls_tokens?=?repeat(self.cls_token,?'()?n?d?->?b?n?d',?b=b)
#?跟前面的分塊為x(10,64,?128)的進(jìn)行concat
#?得到(10,?65,?128)向量
x?=?torch.cat((cls_tokens,?x),?dim=1)
知道這個操作,我們也就能明白為什么前面的pos_embedding的第一維也要加1了,后續(xù)將pos_embedding也加入到x
?x?+=?self.pos_embedding[:,?:(n?+?1)]
分類
分類頭很簡單,加入了LayerNorm和兩層全連接層實(shí)現(xiàn)的,采用的是GELU激活函數(shù)。代碼如下
self.mlp_head?=?nn.Sequential(
????????????nn.LayerNorm(dim),
????????????nn.Linear(dim,?mlp_dim),
????????????nn.GELU(),
????????????nn.Dropout(dropout),
????????????nn.Linear(mlp_dim,?num_classes)
????????)
最終分類我們只取第一個,也就是用于分類的token,輸入到分類頭里,得到最后的分類結(jié)果
self.to_cls_token?=?nn.Identity()
#?forward前向部分
x?=?self.transformer(x,?mask)
x?=?self.to_cls_token(x[:,?0])
return?self.mlp_head(x)
可以看到整個流程是非常簡單的,下面是ViT的整體代碼
class?ViT(nn.Module):
????def?__init__(self,?*,?image_size,?patch_size,?num_classes,?dim,?depth,?heads,?mlp_dim,?channels=3,?dropout=0.,
?????????????????emb_dropout=0.):
????????super().__init__()
????????assert?image_size?%?patch_size?==?0,?'Image?dimensions?must?be?divisible?by?the?patch?size.'
????????num_patches?=?(image_size?//?patch_size)?**?2
????????patch_dim?=?channels?*?patch_size?**?2
????????assert?num_patches?>?MIN_NUM_PATCHES,?f'your?number?of?patches?({num_patches})?is?way?too?small?for?attention?to?be?effective?(at?least?16).?Try?decreasing?your?patch?size'
????????self.patch_size?=?patch_size
????????self.pos_embedding?=?nn.Parameter(torch.randn(1,?num_patches?+?1,?dim))
????????self.patch_to_embedding?=?nn.Linear(patch_dim,?dim)
????????self.cls_token?=?nn.Parameter(torch.randn(1,?1,?dim))
????????self.dropout?=?nn.Dropout(emb_dropout)
????????self.transformer?=?Transformer(dim,?depth,?heads,?mlp_dim,?dropout)
????????self.to_cls_token?=?nn.Identity()
????????self.mlp_head?=?nn.Sequential(
????????????nn.LayerNorm(dim),
????????????nn.Linear(dim,?mlp_dim),
????????????nn.GELU(),
????????????nn.Dropout(dropout),
????????????nn.Linear(mlp_dim,?num_classes)
????????)
????def?forward(self,?img,?mask=None):
????????p?=?self.patch_size
????????x?=?self.patch_to_embedding(x)
????????b,?n,?_?=?x.shape
????????cls_tokens?=?repeat(self.cls_token,?'()?n?d?->?b?n?d',?b=b)
????????x?=?torch.cat((cls_tokens,?x),?dim=1)
????????x?+=?self.pos_embedding[:,?:(n?+?1)]
????????x?=?self.dropout(x)
????????x?=?self.transformer(x,?mask)
????????x?=?self.to_cls_token(x[:,?0])
????????return?self.mlp_head(x)
實(shí)驗部分
與Transformer一樣,ViT也有規(guī)模不一樣的模型設(shè)置,如下圖所示

可以看到整體模型還是挺大的,而經(jīng)過大數(shù)據(jù)集的預(yù)訓(xùn)練后,性能也超過了當(dāng)前CNN的一些SOTA結(jié)果

另外作者還給了注意力觀察得到的圖片塊,我的一點(diǎn)猜想是可能有利于對神經(jīng)網(wǎng)絡(luò)可解釋性的研究。

總結(jié)
繼DETR后,這又是一個CV和NLP結(jié)合的工作。思想非常的樸素簡單,就是拿最原始的Transformer模型來做圖像分類。現(xiàn)有的性能還需要大量的數(shù)據(jù)來訓(xùn)練,期待后續(xù)工作對ViT做一些改進(jìn),降低其訓(xùn)練時間和所需數(shù)據(jù)量,讓人人都能玩得起ViT!
歡迎關(guān)注GiantPandaCV, 在這里你將看到獨(dú)家的深度學(xué)習(xí)分享,堅持原創(chuàng),每天分享我們學(xué)習(xí)到的新鮮知識。( ? ?ω?? )?
有對文章相關(guān)的問題,或者想要加入交流群,歡迎添加BBuf微信:
為了方便讀者獲取資料以及我們公眾號的作者發(fā)布一些Github工程的更新,我們成立了一個QQ群,二維碼如下,感興趣可以加入。
