Building PyTorch Models from Scratch | Building a Transformer Network
Preface: This article covers the basic Transformer pipeline: two ways to implement patch splitting, several ways to implement the Position Embedding, how the Encoder is implemented, two ways to perform the final classification, and, most importantly, the data formats involved.

Patch Splitting
There are currently two ways to split an image into patches: slicing it directly, or applying a convolution whose kernel size and stride are both equal to the patch size.
Direct splitting
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

self.to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
    nn.Linear(patch_dim, dim),
)
# Suppose images has shape [32, 200, 400, 3].
# Equivalent to view/reshape:
rearrange(images, 'b h w c -> (b h) w c')  # shape becomes (32*200, 400, 3)
# Equivalent to permute:
rearrange(images, 'b h w c -> b c h w')    # shape becomes (32, 3, 200, 400)
# Operations that are hard to express with view/reshape/permute alone:
rearrange(images, 'b h w c -> (b c w) h')  # shape becomes (32*3*400, 200)
# The patch-splitting pattern used in to_patch_embedding above:
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
Convolutional splitting
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
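As a quick sanity check, here is a minimal, self-contained sketch (all sizes are illustrative assumptions) showing that convolutional splitting produces the expected token shape:

import torch
import torch.nn as nn

B, C, H, W = 2, 3, 224, 224          # assumed batch of 224x224 RGB images
patch_size, embed_dim = 16, 768

proj = nn.Conv2d(C, embed_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(B, C, H, W)
tokens = proj(x).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([2, 196, 768]); (224/16)*(224/16) = 196 patches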
Position Embedding
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

x += self.pos_embedding[:, :(n + 1)]

# It is n + 1 because ViT randomly initializes a class token and concatenates it
# with the patch tokens, so the total number of tokens is num_patches + 1.
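A minimal sketch (with illustrative sizes) of why the table needs num_patches + 1 entries:

import torch
import torch.nn as nn

b, num_patches, dim = 2, 196, 768
pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
cls_token = nn.Parameter(torch.randn(1, 1, dim))

x = torch.randn(b, num_patches, dim)      # patch tokens
cls_tokens = cls_token.expand(b, -1, -1)  # one copy of the cls token per sample
x = torch.cat((cls_tokens, x), dim=1)     # [2, 197, 768]
x = x + pos_embedding[:, :x.shape[1]]     # broadcast add over the batch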
Swin Transformer instead initializes the absolute position embedding to zeros and applies a truncated-normal initialization:

from timm.models.layers import trunc_normal_

self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
A third option, used in TimeSformer, is a learnable nn.Embedding table indexed by position:

self.pos_emb = torch.nn.Embedding(num_positions + 1, dim)
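A minimal sketch (illustrative sizes) of how such a table is looked up and added to the tokens:

import torch
import torch.nn as nn

dim, num_positions = 768, 196
pos_emb = nn.Embedding(num_positions + 1, dim)

x = torch.randn(2, 197, dim)            # token sequences, cls token included
positions = torch.arange(x.shape[1])    # 0, 1, ..., 196
x = x + pos_emb(positions)              # lookup, then broadcast over the batch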
Encoder
Multi-head Self-attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        # skip the output projection when a single head already matches dim
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # one linear layer produces q, k and v together
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # split the last dimension into (heads, dim_head)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # scaled dot-product attention scores
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        # merge the heads back into a single feature dimension
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
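A quick shape check, assuming the Attention class above plus the torch/nn/rearrange imports (the sizes are illustrative):

attn_layer = Attention(dim = 768, heads = 8, dim_head = 64)
x = torch.randn(2, 197, 768)
print(attn_layer(x).shape)  # torch.Size([2, 197, 768]); attention preserves the token shape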
FeedForward
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x  # residual connection around attention
            x = ff(x) + x    # residual connection around the feed-forward block
        return x
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
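Putting the three classes together, a minimal sketch (illustrative hyperparameters) of running the encoder:

encoder = Transformer(dim = 768, depth = 6, heads = 8, dim_head = 64, mlp_dim = 2048)
x = torch.randn(2, 197, 768)
print(encoder(x).shape)  # torch.Size([2, 197, 768]); the encoder preserves the token shape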
Classification Methods
# Creating the cls_token
from einops import repeat

self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
################################
# The classification head
self.mlp_head = nn.Sequential(
    nn.LayerNorm(dim),
    nn.Linear(dim, num_classes)
)

x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
Swin Transformer does not use a cls_token. Instead, after the Encoder it average-pools all the tokens and feeds the result through a fully connected layer.
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

x = self.avgpool(x.transpose(1, 2))  # B C 1
x = torch.flatten(x, 1)
x = self.head(x)
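A self-contained sketch of that pooling step (token count and channel width are illustrative):

import torch
import torch.nn as nn

x = torch.randn(2, 49, 768)          # encoder output: [B, L, C]
avgpool = nn.AdaptiveAvgPool1d(1)
pooled = avgpool(x.transpose(1, 2))  # [B, C, 1]: average over the L tokens
pooled = torch.flatten(pooled, 1)    # [B, C], ready for the linear head
print(pooled.shape)                  # torch.Size([2, 768])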
Putting all of the above together gives a complete model.
# helper from the linked vit-pytorch source: accept an int or an (h, w) tuple
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
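A minimal sketch of instantiating and running the model; the hyperparameters below are illustrative, not prescribed by the paper:

model = ViT(
    image_size = 224, patch_size = 16, num_classes = 1000,
    dim = 768, depth = 12, heads = 12, mlp_dim = 3072
)
img = torch.randn(1, 3, 224, 224)
logits = model(img)
print(logits.shape)  # torch.Size([1, 1000])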
Data Transformations

The key data-format transformation throughout is the patch-splitting Rearrange pattern, which turns an image batch into a sequence of flattened patch tokens:

Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
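A self-contained sketch tracing the shape (assuming a 224x224 RGB image and 16x16 patches):

import torch
from einops.layers.torch import Rearrange

to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = 16, p2 = 16)
img = torch.randn(1, 3, 224, 224)
patches = to_patches(img)
print(patches.shape)  # torch.Size([1, 196, 768]): 14*14 = 196 patches, 16*16*3 = 768 dims each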
Reference code:
ViT: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
Swin: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
TimeSformer: https://github.com/lucidrains/TimeSformer-pytorch/blob/main/timesformer_pytorch/timesformer_pytorch.py