Vision Transformer | 超詳解+個人心得
?戳我,查看GAN的系列專輯~!地址:https://zhuanlan.zhihu.com/p/435636952
論文名稱:《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》
論文地址:https://arxiv.org/pdf/2010.11929.pdf
pytorch版本代碼:https://github.com/lucidrains/vit-pytorch
01
這周開始閱讀VIT,讀完后頗有感觸,在這里寫下一些對論文的理解以及個人思考。
We show that this reliance on CNNs is not necessary and a pure transformer applied directly to sequences of image patches can perform very well on image classification tasks
本文是基于transformer的encoder部分提出的針對圖像分類任務(wù)的方法,關(guān)于傳統(tǒng)transformer講解可見本人另一拙作:《attention is all your need》
02
首先放圖:

1.1 數(shù)據(jù)預(yù)處理
從圖片的左下角開始看起,我們看到的是一個個被切分好的圖片塊,這里需要對輸入作出解釋:
假設(shè)原始輸入的圖片數(shù)據(jù)是 H x W x C,我們需要對圖片進行塊切割,假設(shè)圖片塊大小為P1 x P2,則最終的塊數(shù)量N為:N = (H/P1)x(W/P2)。
這里需要注意H和W必須是能夠被P整除的
接下來到了圖一正中間的最下面,我們看到圖片塊被拉成一個線性排列的序列,也就是“一維”的存在(以此來模擬transformer中輸入的詞序列,即我們可以把一個圖片塊看做一個詞),即將切分好的圖片塊進行一個展平操作,那么每一個向量的長度為:Patch_dim = P1 x P2 x C。
經(jīng)過上述兩步操作后,我們得到了一個N x Patch_dim的輸入序列。
1.2 Patch + Position Embedding
僅僅拉平成P1 x P2 x C的向量是不夠的,我們需要經(jīng)過一個全連接層,對維度進行縮放,即文中的Patch Embedding,縮放后的維度為dim(使用nn.Linear即可,此處不再贅述),用公式表示即:

從公式中可以看出多了一個?
?
?
這里用一張圖來幫助理解:

經(jīng)過上述操作后,我們得到了想要的數(shù)據(jù)??
1.3 Transformer Encoder
在圖一的中間部分,我們可以看到之前經(jīng)過處理的
被輸入到了Transformer Encoder層,而該層的具體結(jié)構(gòu)正如圖一右側(cè)所示,即下圖:

我們的?
與Transformer類似,我們這里的多頭是什么意思呢?
同樣的,我們想讓模型學(xué)習(xí)全方位、多層次、多角度的信息,學(xué)習(xí)更豐富的信息特征,對于同一張圖片來說,每個人看到的、注意到的部分都會存在一定差異,而在圖像中的多頭恰恰是把這些差異綜合起來進行學(xué)習(xí)。
1.4 MLP Head
結(jié)束了Transformer Encoder,就到了我們最終的分類處理部分,在之前我們進行Encoder的時候通過concat的方式多加了一個用于分類的可學(xué)習(xí)向量,這時我們把這個向量取出來輸入到MLP Head中,即經(jīng)過Layer Normal --> 全連接 --> GELU --> 全連接,我們得到了最終的輸出。
這里作者經(jīng)過實驗選取了GELU作為激活函數(shù)
03
2.1 庫導(dǎo)入
import torchfrom torch import nnfrom einops import rearrange, repeatfrom einops.layers.torch import Rearrange
這里的einops在我們后續(xù)對圖像進行塊切割時候會用到。
2.2 模型主體
def pair(t):return t if isinstance(t, tuple) else (t, t)class ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):super().__init__()image_height, image_width = pair(image_size)patch_height, patch_width = pair(patch_size)assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_height // patch_height) * (image_width // patch_width)patch_dim = channels * patch_height * patch_widthassert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'self.to_patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),nn.Linear(patch_dim, dim),)self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, num_classes))def forward(self, img):x = self.to_patch_embedding(img)b, n, _ = x.shapecls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)x = torch.cat((cls_tokens, x), dim=1)x += self.pos_embedding[:, :(n + 1)]x = self.dropout(x)x = self.transformer(x)x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]x = self.to_latent(x)????????return?self.mlp_head(x)
從forward部分開始,我們可以看到輸入的img依次經(jīng)過了patch_embedding --> concat_cls_tokens --> add_pos_embedding --> transformer --> mlp_head,下面我們對這幾個部分進行逐一介紹:
2.2.1 patch_embedding
self.to_patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),nn.Linear(patch_dim, dim),????????)
這一步通過Rearrange將輸入為[b, c, h, w]的圖片切分為大小為p1*p2的圖片塊,同時通過Linear將維度從patch_dim縮放到dim。
2.2.2 concat_cls_tokens
經(jīng)過上一步后我們通過:
b,?n,?_?=?x.shape得到了輸入圖片的數(shù)量b,以及經(jīng)過切分后的圖片塊總數(shù)n。
接下來我們通過Parameter來生成一個可學(xué)習(xí)的變量:
?self.cls_token?=?nn.Parameter(torch.randn(1,?1,?dim))一個肯定是不夠的,我們通過repeat方法進行重復(fù):
?cls_tokens?=?repeat(self.cls_token,?'()?n?d?->?b?n?d',?b=b)??#?shape為[batch_size,?1,?dim]這樣就生成了一個shape為[b,1,dim]的向量,我們只需將其與原矩陣concat即可
x = torch.cat((cls_tokens, x), dim=1)這里需要注意,經(jīng)過concat后我們的n變?yōu)閚+1,會在下面的添加位置信息時用到。
2.2.3 add_pos_embedding
與生成可學(xué)習(xí)的??
self.pos_embedding?=?nn.Parameter(torch.randn(1,?num_patches?+?1,?dim))接下來我們只需通過逐元素加和的方式添加到原矩陣中去即可
?x?+=?self.pos_embedding[:,?:(n?+?1)]至此數(shù)據(jù)處理部分結(jié)束,接下來我們就要把X輸入到Transformer中去了。
2.3 Transformer部分
這一部分我單獨拎出來講解,首先上代碼:
class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + x????????return?x
這里的depth為Transformer Encoder的堆疊次數(shù),也即該部分深度,我們使用ModuleList既保持代碼整潔又實現(xiàn)了模塊堆疊。
繼續(xù)往下看可以發(fā)現(xiàn)每一層其實都是一個同樣的結(jié)構(gòu),即Attention部分 --> PreNorm --> Feed Forward部分 --> PreNorm。那么我們就分別來看一下這幾步的具體代碼。
首先來看Attention部分:
class Attention(nn.Module):def __init__(self, dim, heads=8, dim_head=64, dropout=0.):super().__init__()inner_dim = dim_head * headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim=-1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):qkv = self.to_qkv(x).chunk(3, dim = -1)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)out = torch.matmul(attn, v)out = rearrange(out, 'b h n d -> b n (h d)')????????return?self.to_out(out)
從代碼中不難看出,我們輸入的X經(jīng)過變換生成Q、K、V
Q×K計算關(guān)聯(lián)性后進行一個 dim_head ** -0.5的維度縮放(此部分在Transformer中有介紹到),緊接著通過softmax計算權(quán)值再與原矩陣V相乘得到out,最后out經(jīng)過一個全連接層進行最終的輸出。
接下來是PreNorm部分:
class PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):????????return?self.fn(self.norm(x),?**kwargs)
這一部分非常簡單,所要實現(xiàn)的就是一個層歸一化處理,這里不做過多介紹。
最后來到Feed Forward部分:
class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout=0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):????????return?self.net(x)
從代碼中可以看出,我們輸入的X進入到容器中,進行了一次全連接 --> GELU --> 全連接的變換
接下來對于Feed Forward的輸入,我們還要做一次層歸一化處理。
在Transformer Encoder部分,這樣的模塊堆疊depth次后,我們來到了最終的分類層。
2.4 MLP Head
在進入分類頭之前,我們需要把之前額外添加的分類專屬向量單獨提取出來:
x?=?x.mean(dim=1)?if?self.pool?==?'mean'?else?x[:,?0]在我們concat后,這個向量就是處于下標(biāo)為0的位置,故提取時只需輸入x[:, 0]即可。這里的mean是我們在輸入時的可選擇項(在2.2 模型主體部分的代碼中)
分類頭其實就是一個全連接層:
self.mlp_head = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, num_classes)????????)
最終的num_classes即我們所需的圖像類別數(shù),至此整個VIT的代碼講解完畢。
04
本文對于Transformer部分的代碼講解不是足夠細(xì)致,只因其不是本文講解重點(后續(xù)會對本文Transformer部分代碼講解做出更新與改進),現(xiàn)有VIT模型的性能還需大量數(shù)據(jù)來訓(xùn)練(在論文中也有提出,小規(guī)模數(shù)據(jù)集的表現(xiàn)并不是很好),但作為繼DERT后的又一項CV與NLP結(jié)合的工作,引爆熱度是毋庸置疑的。
筆者才疏學(xué)淺,望廣大讀者批評指正,不吝賜教!
猜您喜歡:
附下載 |《TensorFlow 2.0 深度學(xué)習(xí)算法實戰(zhàn)》
附下載 |《計算機視覺中的數(shù)學(xué)方法》分享
《基于深度神經(jīng)網(wǎng)絡(luò)的少樣本學(xué)習(xí)綜述》
