實(shí)踐教程 | PyTorch中相對位置編碼的理解

極市導(dǎo)讀
本文重點(diǎn)討論BotNet中的2D相對位置編碼的實(shí)現(xiàn)中的一些細(xì)節(jié)。注意,這里的相對位置編碼方式和Swin Transformer中的不太一樣,讀者可以自行比較。 >>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿
前言
這里討論的相對位置編碼的實(shí)現(xiàn)策略實(shí)際上原始來自于:https://arxiv.org/pdf/1809.04281.pdf。
這里有一篇介紹性的文章:https://gudgud96.github.io/2020/04/01/annotated-music-transformer/, 圖例非常清晰。
首先理解下相對位置自注意力中關(guān)于位置嵌入的一些細(xì)節(jié)。

相對注意力的一些相關(guān)概念。摘自Music Transformer。在不考慮head維度時(shí):
:相對位置嵌入,大小為 :來自Shaw論文中引入的相對位置嵌入的中間表示,大小為 :表示相對位置編碼與query的交互結(jié)果,大小為,即在維度上進(jìn)行了累加 Music Transformer的一點(diǎn)工作就是將這個(gè)會(huì)占用較大存儲(chǔ)空間的中間表示去掉,直接得到,如下圖所示:

要注意這里的表示的是針對相對位置的嵌入,最小相對位置為,最大為0(因?yàn)樾枰紤]因果關(guān)系,前面的i看不到后面的j),所以有個(gè)位置。
而對于我們這里將要討論的不考慮因果關(guān)系的情況,最小相對位置為,最大為。所以我們的位置嵌入形狀為。
代碼分析
首先找份代碼來看看, https://github.com/lucidrains/bottleneck-transformer-pytorch/blob/main/bottleneck_transformer_pytorch/bottleneck_transformer_pytorch.py 實(shí)現(xiàn)的相對位置編碼涉及到幾個(gè)關(guān)鍵的組件:
import torch
import torch.nn as nn
from einops import rearrange
def relative_to_absolute(q):
"""
Converts the dimension that is specified from the axis
from relative distances (with length 2*tokens-1) to absolute distance (length tokens)
borrowed from lucidrains:
https://github.com/lucidrains/bottleneck-transformer-pytorch/blob/main/bottleneck_transformer_pytorch/bottleneck_transformer_pytorch.py#L21
Input: [bs, heads, length, 2*length - 1]
Output: [bs, heads, length, length]
"""
b, h, l, _, device, dtype = *q.shape, q.device, q.dtype
dd = {'device': device, 'dtype': dtype}
col_pad = torch.zeros((b, h, l, 1), **dd)
x = torch.cat((q, col_pad), dim=3) # zero pad 2l-1 to 2l
flat_x = rearrange(x, 'b h l c -> b h (l c)')
flat_pad = torch.zeros((b, h, l - 1), **dd)
flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)
final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)
final_x = final_x[:, :, :l, (l - 1):]
return final_x
def rel_pos_emb_1d(q, rel_emb, shared_heads):
"""
Same functionality as RelPosEmb1D
Args:
q: a 4d tensor of shape [batch, heads, tokens, dim]
rel_emb: a 2D or 3D tensor
of shape [ 2*tokens-1 , dim] or [ heads, 2*tokens-1 , dim]
"""
if shared_heads:
emb = torch.einsum('b h t d, r d -> b h t r', q, rel_emb)
else:
emb = torch.einsum('b h t d, h r d -> b h t r', q, rel_emb)
return relative_to_absolute(emb)
class RelPosEmb1DAISummer(nn.Module):
def __init__(self, tokens, dim_head, heads=None):
"""
Output: [batch head tokens tokens]
Args:
tokens: the number of the tokens of the seq
dim_head: the size of the last dimension of q
heads: if None representation is shared across heads.
else the number of heads must be provided
"""
super().__init__()
scale = dim_head ** -0.5
self.shared_heads = heads if heads is not None else True
if self.shared_heads:
self.rel_pos_emb = nn.Parameter(torch.randn(2 * tokens - 1, dim_head) * scale)
else:
self.rel_pos_emb = nn.Parameter(torch.randn(heads, 2 * tokens - 1, dim_head) * scale)
def forward(self, q):
return rel_pos_emb_1d(q, self.rel_pos_emb, self.shared_heads)
可以看到:
RelPosEmb1DAISummer初始化了rel_pos_emb_1d為relative_to_absolute提供(為了便于書寫,我們將其設(shè)為),通過在relative_to_absolute中各種形變和padding,從而得到了理解的難點(diǎn)在relative_to_absolute中的實(shí)現(xiàn)過程。
這里會(huì)把從一個(gè)tensor轉(zhuǎn)化為一個(gè)的tensor。這個(gè)過程實(shí)際上就是一個(gè)從表中查找的過程。
這里的實(shí)現(xiàn)其實(shí)有些晦澀,直接閱讀代碼是很難明白其中的意義。接下來會(huì)重點(diǎn)說這個(gè)。
需要注意的是,下面的分析都是按照1D的token序列來解釋的,實(shí)際上2D的也是將H和W分別基于1D的策略處理的。也就是將H或者W合并到頭索引那一維度,即這里的 heads,結(jié)果就和1D的一致了,只是還會(huì)多一個(gè)額外的廣播的過程。如下代碼:
import torch.nn as nn
from einops import rearrange
from self_attention_cv.pos_embeddings.relative_embeddings_1D import RelPosEmb1D
class RelPosEmb2DAISummer(nn.Module):
def __init__(self, feat_map_size, dim_head, heads=None):
"""
Based on Bottleneck transformer paper
paper: https://arxiv.org/abs/2101.11605 . Figure 4
Output: qr^T [batch head tokens tokens]
Args:
tokens: the number of the tokens of the seq
dim_head: the size of the last dimension of q
heads: if None representation is shared across heads.
else the number of heads must be provided
"""
super().__init__()
self.h, self.w = feat_map_size # height , width
self.total_tokens = self.h * self.w
self.shared_heads = heads if heads is not None else True
self.emb_w = RelPosEmb1D(self.h, dim_head, heads)
self.emb_h = RelPosEmb1D(self.w, dim_head, heads)
def expand_emb(self, r, dim_size):
# Decompose and unsqueeze dimension
r = rearrange(r, 'b (h x) i j -> b h x () i j', x=dim_size)
expand_index = [-1, -1, -1, dim_size, -1, -1] # -1 indicates no expansion
r = r.expand(expand_index)
return rearrange(r, 'b h x1 x2 y1 y2 -> b h (x1 y1) (x2 y2)')
def forward(self, q):
"""
Args:
q: [batch, heads, tokens, dim_head]
Returns: [ batch, heads, tokens, tokens]
"""
assert self.total_tokens == q.shape[2], f'Tokens {q.shape[2]} of q must \
be equal to the product of the feat map size {self.total_tokens} '
# out: [batch head*w h h]
r_h = self.emb_w(rearrange(q, 'b h (x y) d -> b (h x) y d', x=self.h, y=self.w))
r_w = self.emb_h(rearrange(q, 'b h (x y) d -> b (h y) x d', x=self.h, y=self.w))
q_r = self.expand_emb(r_h, self.h) + self.expand_emb(r_w, self.w)
return q_r
提前的思考
首先我們要明確,為什么對于每個(gè)維度為的token ,其對應(yīng)的整體會(huì)有這樣一個(gè)縮減的過程?
因?yàn)閷τ陂L為的序列中的每一個(gè)元素,實(shí)際上與之可能有關(guān)的元素最多只有個(gè)(雖說是廢話,但是在直接理解時(shí)可能確實(shí)容易忽略這一點(diǎn)。)。
所以對于每個(gè)元素,實(shí)際上這里的并不會(huì)都用到。這里的只是所有可能會(huì)用到的情形(分別對應(yīng)于各種相對距離)。
這里需要說明的一點(diǎn)是,有些相對注意力的策略中,會(huì)使用固定的窗口。
即對于窗口之外的j,和窗口邊界上的j的相對距離認(rèn)為是一樣的, 即,我們這里介紹的可以看做是。
例如這個(gè)實(shí)現(xiàn):https://github.com/TensorUI/relative-position-pytorch/blob/master/relative_position.py
所以這里前面展示的這個(gè)函數(shù) relative_to_absolute 實(shí)際上就是在做這樣一件事:從中抽取對應(yīng)于各個(gè)token真實(shí)存在的相對距離的位置嵌入集合來得到最終的.
背后的動(dòng)機(jī)
為了便于展示這個(gè)代碼描述的過程的動(dòng)機(jī),我們首先構(gòu)造一個(gè)簡單的序列,包含5個(gè)元素,則。這里嵌入的維度為。則位置對應(yīng)的相對距離矩陣可以表示為:

這里紅色標(biāo)記表示各個(gè)位置上的相對距離。我們再看下假定已經(jīng)得到的:

這里對各個(gè)都提供了獨(dú)立的一套嵌入。為了直觀的展示,這里我們也展示了對于這個(gè)相對位置的相對距離,同時(shí)也標(biāo)注了對應(yīng)于嵌入矩陣各列的絕對索引。
接下來我們就需要提取想要的那部分嵌入的tensor了。這個(gè)時(shí)候,我們需要明白,我們要獲取的是哪部分結(jié)果:

這里實(shí)際上就是結(jié)合了圖1中已經(jīng)得到的相對距離和圖2中的,從而就可以明白,紅色的這部分區(qū)域正是我們想要的那部分合理索引對應(yīng)的位置編碼。
稍微整理下, 也就是如下的絕對索引對應(yīng)的嵌入信息(形狀與一致,可以直接元素級(jí)相加):

而前面的代碼 relative_to_absolute 正是在做這樣一件事。就是通過不斷的 padding 和 reshape 來從圖3中獲得圖4中這些絕對索引對應(yīng)的嵌入。
對應(yīng)的流程
關(guān)于代碼的流程,參考鏈接中的圖例非常直觀:
col_pad = torch.zeros((b, h, l, 1), **dd)
x = torch.cat((q, col_pad), dim=3) # zero pad 2l-1 to 2l

flat_x = rearrange(x, 'b h l c -> b h (l c)')

flat_pad = torch.zeros((b, h, l - 1), **dd)
flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)

final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)
final_x = final_x[:, :, :l, (l - 1):]

將提取的內(nèi)容對應(yīng)于原始的中,可以看到是如下區(qū)域,正如前面的分析所示。

參考
AI SUMMER這篇文章寫的很好,很直觀,很清晰:https://theaisummer.com/positional-embeddings/
如果覺得有用,就請分享到朋友圈吧!
長按掃描下方二維碼添加小助手。
可以一起討論遇到的問題
聲明:轉(zhuǎn)載請說明出處
掃描下方二維碼關(guān)注【集智書童】公眾號(hào),獲取更多實(shí)踐項(xiàng)目源碼和論文解讀,非常期待你我的相遇,讓我們以夢為馬,砥礪前行!

