最強(qiáng)輔助Visualizer:簡(jiǎn)化你的Vision Transformer可視化!
點(diǎn)擊上方“視學(xué)算法”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)
?作者 | 洛英
Visualizer 是一個(gè)輔助深度學(xué)習(xí)模型中 Attention 模塊可視化的小工具,主要功能是幫助取出嵌套在模型深處的 Attention Map。
為了可視化 Attention Map,你是否有以下苦惱:
1. Return 大法好:通過 return 將嵌套在模型深處的 Attention Map 一層層地返回回來,然后訓(xùn)練模型的時(shí)候又不得不還原;
2. 全局大法好:使用全局變量在 Attention 函數(shù)中直接記錄 Attention Map,結(jié)果訓(xùn)練的時(shí)候忘改回來導(dǎo)致 OOM。
不管你有沒有,反正我有,由于可視化分析不是一錘子買賣,實(shí)際過程中你往往需要在訓(xùn)練-可視化-訓(xùn)練-可視化兩種狀態(tài)下反復(fù)橫跳,所以不適合采用以上兩種方式進(jìn)行可視化分析。

PyTorch hook 的局限性
handle = net.conv2.register_forward_hook(hook)
這樣我們就可以拿出來 net.conv2 這層的輸出啦。
然而!進(jìn)行這樣操作的前提是我們知道要取出來的模塊名,但是 Transformer 類模型一般是這樣定義的(以 Vit 為例)。
class VisionTransformer(nn.Module):
def __init__(self, *args, **kwargs):
...
self.blocks = nn.Sequential(*[Block(...) for i in range(depth)])
...然后每個(gè) Block 中都有一個(gè) Attention 。
class Block(nn.Module):
def __init__(self, *args, **kwargs):
...
self.attn = Attention(...)
...然后我們想要的 attention map 又在 Attention 里面。
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
...
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn) # <-在這
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x1. 嵌套太深,模塊名不清晰,我們根本不知道我們要取的 attention map 怎么以 model.bla.bla.bla 這樣一直點(diǎn)出來!
2. 一般來說,Transformer 中 attention map 每層都有一個(gè),一個(gè)個(gè)注冊(cè)實(shí)在太麻煩了。
那怎么辦呢....
Visualizer!
所以我就思考并查找能否通過更簡(jiǎn)潔的方法來得到 Attention Map(尤其是 Transformer 的),而 visualizer 就是其中的一種,它具有以下特點(diǎn):
精準(zhǔn)直接,你可以取出任何變量名的模型中間結(jié)果;
快捷方便,一個(gè)操作,就可以同時(shí)取出 Transformer 類模型中的所有 attention map;
非侵入式,你無須修改函數(shù)內(nèi)的任何一行代碼;
訓(xùn)練-測(cè)試一致,可視化完成后,訓(xùn)練時(shí)無須再將代碼改回來。

python setup.py install使用方法一
from visualizer import get_local
@get_local('attention_map') # 我要拿attention_map這個(gè)變量,所以把他傳參給get_local
def your_attention_function(*args, **kwargs):
...
attention_map = ...
...
return ...在可視化代碼里,我們這么寫:
from visualizer import get_local
get_local.activate() # 激活裝飾器
from ... import model # 被裝飾的模型一定要在裝飾器激活之后導(dǎo)入??!
# load model and data
...
out = model(data)
cache = get_local.cache # -> {'your_attention_function': [attention_map]}使用 Pytorch 時(shí)我們往往會(huì)將模塊定義成一個(gè)類,此時(shí)也是一樣只要裝飾類內(nèi)計(jì)算出 attention_map 的函數(shù)即可:
from visualizer import get_local
class Attention(nn.Module):
def __init__(self):
...
@get_local('attn_map')
def forward(self, x):
...
attn_map = ...
...
return ...其他細(xì)節(jié)請(qǐng)參考:

可視化結(jié)果
因?yàn)槠胀?Vit 所有 Attention map 都是在 Attention.forward 中計(jì)算出來的,所以只要簡(jiǎn)單地裝飾一下這個(gè)函數(shù),我們就可以同時(shí)取出 vit 中 12 層 Transformer 的所有 Attention Map!
一個(gè) Head 的結(jié)果:

一層所有 Heads 的結(jié)果:


在可視化這張圖片的過程中,我也發(fā)現(xiàn)了一些有趣的現(xiàn)象。
首先,靠前層的 Attention 大多只關(guān)注自身,進(jìn)行真·self attention 來理解自身的信息,比如這是第一層所有 Head 的 Attention Map,其特點(diǎn)就是呈現(xiàn)出明顯的對(duì)角線模式。



最后,重要信息聚合到某些特定的 token 上,Attention 出現(xiàn)與 query 無關(guān)的情況,在 Attention Map 上呈現(xiàn)出豎線的模式,如下第 11 層的 Attention Map:

注意
在使用 visualizer 的過程中,有以下幾點(diǎn)需要注意:
1. 想要可視化的變量在函數(shù)內(nèi)部不能被后續(xù)的同名變量覆蓋了,因?yàn)?get_local 取的是對(duì)應(yīng)名稱變量在函數(shù)中的最終值;
2. 進(jìn)行可視化時(shí),get_local.activate() 一定要在導(dǎo)入模型前完成,因?yàn)?python 裝飾器是在導(dǎo)入時(shí)執(zhí)行的;
3. 訓(xùn)練時(shí)你不需要?jiǎng)h除裝飾的代碼,因?yàn)樵?get_local.activate() 沒有執(zhí)行的情況下,attention 函數(shù)不會(huì)被裝飾,故沒有任何性能損失(同上一點(diǎn),因?yàn)?python 裝飾器是在導(dǎo)入時(shí)執(zhí)行的)。
其他
當(dāng)然,其實(shí) get_local 本質(zhì)就是獲取一個(gè)函數(shù)中某個(gè)局部變量的最終值,所以它應(yīng)該還有其他更有趣的用途。

小結(jié)

點(diǎn)個(gè)在看 paper不斷!
