An Illustrated Guide to Swin Transformer
Introduction
Applying Transformers to the image domain currently faces two main challenges:
- Visual entities vary greatly in scale, so a vision Transformer does not necessarily perform well in every scenario.
- Images have high resolution and many pixels, and the Transformer's global self-attention makes the computation expensive.
To address these two problems, the paper proposes Swin Transformer, which combines a shifted-window operation with a hierarchical design.
The windowing scheme consists of non-overlapping local windows together with cross-window connections between them. Restricting the attention computation to a window both introduces the locality of CNN convolutions and saves computation.

Swin Transformer performs very well across the major vision tasks.
This post is fairly long and follows the official open-source code (https://github.com/microsoft/Swin-Transformer). If you are interested, also read the original paper (https://arxiv.org/pdf/2103.14030.pdf).
Overall Architecture
Let's first look at the overall architecture of Swin Transformer.

The model adopts a hierarchical design with 4 Stages in total. Each Stage reduces the resolution of the input feature map, expanding the receptive field layer by layer like a CNN.
At the input, a Patch Embedding step splits the image into patches and embeds them. Each Stage then consists of a Patch Merging module and several Blocks. Patch Merging mainly reduces the spatial resolution at the beginning of each Stage, while each Block (shown on the right side of the architecture figure) is mainly composed of LayerNorm, MLP, Window Attention, and Shifted Window Attention. (To keep the explanation simple, I will omit some parameters.)
class SwinTransformer(nn.Module):
    def __init__(...):
        super().__init__()
        ...
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

        self.pos_drop = nn.Dropout(p=drop_rate)

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(...)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
A few things here are handled differently from ViT:
- ViT adds a positional encoding to the embeddings at the input. In Swin-T this is only an option (self.ape); instead, Swin-T adds a relative position bias when computing attention.
- ViT prepends an extra learnable token and uses it for classification, whereas Swin-T simply averages the output tokens, much like the global average pooling at the end of a CNN. A small sketch contrasting the two follows below.
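As a minimal sketch of my own (not from the Swin repo; the sizes and the Linear head are made up for illustration), the two classification heads look like this:

import torch
import torch.nn as nn

B, L, C, num_classes = 2, 3136, 96, 1000    # made-up sizes for illustration
tokens = torch.randn(B, L, C)               # patch tokens (B, L, C)

# ViT-style: prepend a learnable [CLS] token and classify from it
cls_token = nn.Parameter(torch.zeros(1, 1, C))
x = torch.cat([cls_token.expand(B, -1, -1), tokens], dim=1)  # (B, 1+L, C)
vit_feature = x[:, 0]                                        # (B, C)

# Swin-style: average over all tokens, like CNN global average pooling
swin_feature = tokens.mean(dim=1)                            # (B, C)

head = nn.Linear(C, num_classes)
print(head(vit_feature).shape, head(swin_feature).shape)     # both (B, num_classes)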
Next, let's look at how each component is built.
Patch Embedding
Before entering the Blocks, the image needs to be split into patches, which are then embedded as vectors.
Concretely, the original image is cut into patch_size * patch_size patches, and each patch is embedded.
This can be done with a 2D convolution whose stride and kernel size are both set to patch_size, with the number of output channels determining the embedding dimension. Finally, the H and W dimensions are flattened into a single token dimension and moved in front of the channel dimension.
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)      # -> (img_size, img_size)
        patch_size = to_2tuple(patch_size)  # -> (patch_size, patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        # assuming the default arguments
        x = self.proj(x)              # output is (N, 96, 224/4, 224/4)
        x = torch.flatten(x, 2)       # flatten the H, W dims: (N, 96, 56*56)
        x = torch.transpose(x, 1, 2)  # move the channel dim to the end: (N, 56*56, 96)
        if self.norm is not None:
            x = self.norm(x)
        return x
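As a quick shape check (my own example; it relies on the PatchEmbed class above and assumes timm is installed for to_2tuple), with the default arguments a 224x224 RGB image becomes 56*56 = 3136 tokens of dimension 96:

x = torch.randn(1, 3, 224, 224)
patch_embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
print(patch_embed(x).shape)  # torch.Size([1, 3136, 96])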
Patch Merging
This module downsamples at the start of each Stage to reduce the resolution and adjust the channel count, which creates the hierarchical design and also saves some computation.
In a CNN, this would be done with a stride-2 convolution or pooling layer at the start of each Stage.
Each downsampling halves the resolution, so elements are taken at intervals of 2 along both the row and column directions.
They are then concatenated into a single tensor and flattened. The channel dimension becomes 4 times the original (because H and W are each halved), and a linear layer then reduces the channel dimension to twice the original.
class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x
Below is an illustration (input tensor with N=1, H=W=8, C=1, not including the final linear layer).

Personally, this feels like the inverse of PixelShuffle. A small shape check follows below.
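Here is a small shape check of my own (matching the 8x8, C=1 setting of the illustration, and using the PatchMerging class above): 64 tokens become 16 tokens, and the channel dimension goes from C to 2C after the linear reduction.

x = torch.arange(64, dtype=torch.float32).view(1, 64, 1)  # B=1, H*W=64, C=1
merge = PatchMerging(input_resolution=(8, 8), dim=1)
print(merge(x).shape)  # torch.Size([1, 16, 2])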
Window Partition/Reverse
The window_partition function splits a tensor into windows of a given size, reshaping it from (B, H, W, C) to (num_windows*B, window_size, window_size, C), where num_windows = H*W / (window_size*window_size) is the number of windows per image. The window_reverse function is the corresponding inverse. Both functions are used later in Window Attention; a round-trip check follows the code below.
def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
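A quick sanity check of my own (an 8x8 feature map split into 2x2 windows) showing that window_reverse undoes window_partition:

x = torch.randn(1, 8, 8, 3)                                 # B, H, W, C
windows = window_partition(x, window_size=2)                # (16, 2, 2, 3): 4*4 windows
restored = window_reverse(windows, window_size=2, H=8, W=8)
print(windows.shape, torch.equal(x, restored))              # torch.Size([16, 2, 2, 3]) True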
Window Attention
This is the key part of the paper. A conventional Transformer computes attention globally, which makes the computational complexity very high. Swin Transformer instead restricts the attention computation to each window, which reduces the amount of computation.
Let's first look at the formula (from the paper): Attention(Q, K, V) = SoftMax(Q K^T / sqrt(d) + B) V, where B is the relative position bias.
The main difference from the original attention formula is this relative position bias added to the Q K^T term. The ablation experiments in the paper show that adding the relative position bias improves model performance.
from timm.models.layers import trunc_normal_

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads  # nH
        head_dim = dim // num_heads  # channels per attention head
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias:
        # a learnable table of shape ((2*Wh-1) * (2*Ww-1), nH), used later for the relative position encoding
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

        # relative position index computation (walked through below)...
Below I pull out the logic related to the relative position encoding, since this part is somewhat convoluted.
First, the attention tensor computed from Q and K has shape (numWindows*B, num_heads, window_size*window_size, window_size*window_size).
For this attention tensor, the coordinates of the other elements differ depending on which element is taken as the origin. Taking window_size=2 as an example, the relative positions are shown in the figure below.

First, we use torch.arange and torch.meshgrid to generate the coordinates, again with window_size=2 as the example:
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.meshgrid([coords_h, coords_w]) # -> 2*(wh, ww)
"""
(tensor([[0, 0],
[1, 1]]),
tensor([[0, 1],
[0, 1]]))
"""
Then we stack them and flatten the result into a tensor of shape (2, Wh*Ww):
coords = torch.stack(coords) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
"""
tensor([[0, 0, 1, 1],
[0, 1, 0, 1]])
"""
Using broadcasting, we insert an extra dimension (in the second position for one copy and the third position for the other) and subtract, obtaining a tensor of shape (2, Wh*Ww, Wh*Ww):
relative_coords_first = coords_flatten[:, :, None] # 2, wh*ww, 1
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww
relative_coords = relative_coords_first - relative_coords_second  # final shape: 2, Wh*Ww, Wh*Ww
因?yàn)椴扇〉氖窍鄿p,所以得到的索引是從負(fù)數(shù)開(kāi)始的,我們加上偏移量,讓其從0開(kāi)始。
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
Next we need to collapse this into a one-dimensional offset. But consider coordinates like (1, 2) and (2, 1): they are different in 2D, yet simply adding the x and y coordinates to form a 1D offset would give the same value.

So we multiply the first coordinate by a factor to keep them distinguishable:
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1

Finally, we sum over the last dimension to get a single 1D index, and register it as a buffer that does not participate in training:
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
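Putting the snippets above together, here is a standalone sketch of my own (window_size=2) that prints the final relative_position_index:

import torch

window_size = (2, 2)
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))       # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                        # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)                # Wh*Ww, Wh*Ww
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])

Each entry is an index into the (2*Wh-1)*(2*Ww-1) = 9 rows of the relative_position_bias_table.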
Next, let's look at the forward code:
def forward(self, x, mask=None):
    """
    Args:
        x: input features with shape of (num_windows*B, N, C)
        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
    """
    B_, N, C = x.shape

    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))

    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    attn = attn + relative_position_bias.unsqueeze(0)  # bias broadcasts as (1, nH, Wh*Ww, Wh*Ww)

    if mask is not None:  # analyzed further below
        ...
    else:
        attn = self.softmax(attn)

    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x
Step by step:
- The input tensor has shape (numWindows*B, window_size*window_size, C) (how it gets this shape is explained later).
- It passes through the self.qkv linear layer, is reshaped, and has its axes permuted, giving a tensor of shape (3, numWindows*B, num_heads, window_size*window_size, C // num_heads), which is split into q, k, v.
- Following the formula, q is multiplied by the scale factor and matrix-multiplied with k (whose last two dimensions are swapped to satisfy the matmul shape requirement), producing an attn tensor of shape (numWindows*B, num_heads, window_size*window_size, window_size*window_size).
- Earlier we set up a learnable position-bias table of shape ((2*window_size-1)*(2*window_size-1), num_heads). We index it with the precomputed self.relative_position_index, reshape the result to (window_size*window_size, window_size*window_size, num_heads), and add it to the attn tensor.
- Ignoring the mask case for now, the rest is the same as a standard Transformer: softmax, dropout, matrix multiplication with v, then a final linear projection and dropout.
A pure-shape sketch of this flow follows below.
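This sketch is my own and uses made-up sizes (64 windows, N = 7*7 = 49 tokens per window, C = 96, 3 heads) purely to illustrate the tensor shapes described above:

import torch
import torch.nn as nn

B_, N, C, num_heads = 64, 49, 96, 3       # numWindows*B, tokens per window, channels, heads
x = torch.randn(B_, N, C)

qkv_proj = nn.Linear(C, C * 3)
qkv = qkv_proj(x).reshape(B_, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]          # each: (64, 3, 49, 32)

attn = (q * (C // num_heads) ** -0.5) @ k.transpose(-2, -1)
print(attn.shape)                         # torch.Size([64, 3, 49, 49])

out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B_, N, C)
print(out.shape)                          # torch.Size([64, 49, 96])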
Shifted Window Attention
The preceding Window Attention computes attention within each window. To allow information to flow between windows, Swin Transformer also introduces the shifted window operation.

On the left is Window Attention with non-overlapping windows; on the right is Shifted Window Attention with the windows shifted. The shifted windows contain elements that originally belonged to adjacent windows. However, this introduces a new problem: the number of windows increases, from the original four windows to nine.
In the actual code, this is implemented indirectly by shifting the feature map and applying a mask to the attention, which keeps the original number of windows while producing an equivalent result.
Feature map shift operation
The feature map shift is implemented in the code with torch.roll; below is an illustration.

For the reverse cyclic shift, simply set the shifts argument to the corresponding positive values. A small torch.roll demo follows below.
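Here is a tiny demo of my own on a 4x4 feature map: shifting by -1 along H and W moves the top-left content cyclically toward the bottom-right, and shifting back by +1 restores it:

import torch

x = torch.arange(16).view(1, 4, 4, 1)                        # B, H, W, C
shifted = torch.roll(x, shifts=(-1, -1), dims=(1, 2))        # cyclic shift
restored = torch.roll(shifted, shifts=(1, 1), dims=(1, 2))   # reverse cyclic shift
print(shifted[0, :, :, 0])
# tensor([[ 5,  6,  7,  4],
#         [ 9, 10, 11,  8],
#         [13, 14, 15, 12],
#         [ 1,  2,  3,  0]])
print(torch.equal(x, restored))                              # True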
Attention Mask
In my opinion this is the essence of Swin Transformer: by constructing a suitable mask, Shifted Window Attention achieves an equivalent computation while using the same number of windows as Window Attention.
First, we assign an index to each window after the shift and apply a roll operation (window_size=2, shift_size=1).

When computing attention, we want query/key pairs with the same index to interact, and the results of query/key pairs with different indices to be ignored.
The correct final result is shown in the figure below.

To obtain the correct result while keeping the original four windows, we must add a mask to the attention result (as shown on the far right of the figure above).
The relevant code is as follows:
if self.shift_size > 0:
    # calculate attention mask for SW-MSA
    H, W = self.input_resolution
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
With the settings from the figure above, this code produces the following mask:
tensor([[[[[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.]]],
[[[ 0., -100., 0., -100.],
[-100., 0., -100., 0.],
[ 0., -100., 0., -100.],
[-100., 0., -100., 0.]]],
[[[ 0., 0., -100., -100.],
[ 0., 0., -100., -100.],
[-100., -100., 0., 0.],
[-100., -100., 0., 0.]]],
[[[ 0., -100., -100., -100.],
[-100., 0., -100., -100.],
[-100., -100., 0., -100.],
[-100., -100., -100., 0.]]]]])
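The printout above corresponds to a toy setting of H = W = 4, window_size = 2, shift_size = 1 (and appears to include the two broadcast dimensions added in the forward pass, giving shape (1, nW, 1, N, N)). Here is a standalone sketch of my own that reproduces it, with the window partition inlined:

import torch

H, W, window_size, shift_size = 4, 4, 2, 1

img_mask = torch.zeros((1, H, W, 1))
h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# inline equivalent of window_partition followed by the flatten to (nW, window_size*window_size)
mask_windows = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.unsqueeze(1).unsqueeze(0))  # matches the tensor shown above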
The forward pass of the Window Attention module shown earlier contains this segment:
if mask is not None:
    nW = mask.shape[0]
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)
The mask is added to the attention result before the softmax. Masked positions are set to -100, so after the softmax their weights are effectively zero and those positions are ignored (see the tiny check below).
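A small numeric check of my own, with made-up scores, showing why -100 is enough:

import torch

scores = torch.tensor([1.0, 2.0, 3.0, 4.0])
masked = scores + torch.tensor([0.0, -100.0, 0.0, -100.0])
print(torch.softmax(masked, dim=-1))
# the unmasked positions get roughly 0.12 and 0.88, while the masked positions
# get weights on the order of 1e-44, i.e. effectively zero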
Overall Transformer Block Architecture

Two consecutive Blocks are structured as shown above. Note that the number of Blocks in a Stage must be even, because the layers alternate between one with Window Attention and one with Shifted Window Attention.
Let's look at the Block's forward code:
def forward(self, x):
    H, W = self.input_resolution
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"

    shortcut = x
    x = self.norm1(x)
    x = x.view(B, H, W, C)

    # cyclic shift
    if self.shift_size > 0:
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    else:
        shifted_x = x

    # partition windows
    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

    # W-MSA/SW-MSA
    attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

    # merge windows
    attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
    shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

    # reverse cyclic shift
    if self.shift_size > 0:
        x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
    else:
        x = shifted_x
    x = x.view(B, H * W, C)

    # FFN
    x = shortcut + self.drop_path(x)
    x = x + self.drop_path(self.mlp(self.norm2(x)))

    return x
The overall flow is:
- Apply LayerNorm to the feature map.
- Decide whether to shift the feature map, based on self.shift_size.
- Partition the feature map into windows.
- Compute attention; self.attn_mask distinguishes Window Attention from Shifted Window Attention.
- Merge the windows back.
- If a shift was applied earlier, reverse it to restore the original layout.
- Apply DropPath and the residual connection, then another LayerNorm + MLP with DropPath and a residual connection.
Experimental Results

With ImageNet-22K pretraining, Swin Transformer reaches an impressive 86.4% accuracy on ImageNet classification. It also performs very well on detection, segmentation, and other tasks; if you are interested, see the experiments section at the end of the paper.
Conclusion
The ideas in this paper are excellent: introducing the concept of windows brings in the locality of CNNs while keeping the overall computation under control. In the Shifted Window Attention part, a mask combined with a shift operation achieves an equivalent computation very elegantly. The authors' code is also a pleasure to read; it is well worth studying.
Follow GiantPandaCV for exclusive deep learning content. We stick to original writing and share what we learn every day. ( ? ?ω?? )?
If you have questions about this article, or would like to join our discussion group, feel free to add BBuf on WeChat:
