transformer 中的 attention
來源:知乎—皮特潘
class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head * headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim = -1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):qkv = self.to_qkv(x).chunk(3, dim = -1)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)out = torch.matmul(attn, v)out = rearrange(out, 'b h n d -> b n (h d)')????????return?self.to_out(out)
attention和CNN、RNN、FC、GCN等都是一個級別的東西,用來提取特征;既然是特征提取,一定有權重(W+B)存在。 attention的優(yōu)點:可以像CNN一樣并行運算 + 像RNN一樣通過一層就擁有全局資訊。有一個東西也可以做到,那就是FC,但是FC有個弱點:對輸入尺寸有限制,說白了不好適應可變輸入數(shù)據(jù),這對于序列無疑是非常不友好的。 可以像CNN一樣并行運算 ,其實CNN運算也是通過im2col或winograd等轉化為矩陣運算的。 RNN不能并行,所以通常它處理的數(shù)據(jù)有“時序”這個特點。既然是“時序”,那么就不是同一個時刻完成的,所以不能并行化。
batch維度:大家利用同樣的權重和操作提取特征,可以理解為for循環(huán)式,相互之間沒有信息交互; multi head維度:同batch類似,不過是利用的不同權重和相同操作提取特征,最后concate一起使用; FC層:是作用在每一個特征上,類似CNN中的1X1,可以叫“pointwise”,和序列長度沒有關系;因為序列中所有的特征經(jīng)過的是同一個FC。

猜您喜歡:
?戳我,查看GAN的系列專輯~!附下載 |《TensorFlow 2.0 深度學習算法實戰(zhàn)》
《基于深度神經(jīng)網(wǎng)絡的少樣本學習綜述》
評論
圖片
表情
