【深度學習】如何理解attention中的Q,K,V?
共 7648字,需瀏覽 16分鐘
·
2024-06-20 12:00
來源 | 知乎問答
地址 | https://www.zhihu.com/question/298810062
本文僅作學術(shù)分享,若侵權(quán)請聯(lián)系后臺刪文處理
回答一:作者-不是大叔
class BertSelfAttention(nn.Module):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 輸入768, 輸出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 輸入768, 輸出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 輸入768, 輸出768
2. 假設三種操作的輸入都是同一個矩陣(暫且先別管為什么輸入是同一個矩陣),這里暫且定為長度為L的句子,每個token的特征維度是768,那么輸入就是(L, 768),每一行就是一個字,像這樣:
乘以上面三種操作就得到了Q/K/V,(L, 768)*(768,768) = (L,768),維度其實沒變,即此刻的Q/K/V分別為:
代碼為:
class BertSelfAttention(nn.Module):
def __init__(self, config):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 輸入768, 輸出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 輸入768, 輸出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 輸入768, 輸出768
def forward(self,hidden_states): # hidden_states 維度是(L, 768)
Q = self.query(hidden_states)
K = self.key(hidden_states)
V = self.value(hidden_states)
3. 然后來實現(xiàn)這個操作:
① 首先是Q和K矩陣乘,(L, 768)*(L, 768)的轉(zhuǎn)置=(L,L),看圖:
③ 然后就是剛才的注意力權(quán)重和V矩陣乘了,如圖:
整個過程在草稿紙上畫一畫簡單的矩陣乘就出來了,一目了然~最后上代碼:
class BertSelfAttention(nn.Module):
def __init__(self, config):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 輸入768, 輸出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 輸入768, 輸出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 輸入768, 輸出768
def forward(self,hidden_states): # hidden_states 維度是(L, 768)
Q = self.query(hidden_states)
K = self.key(hidden_states)
V = self.value(hidden_states)
attention_scores = torch.matmul(Q, K.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = nn.Softmax(dim=-1)(attention_scores)
out = torch.matmul(attention_probs, V)
return out
回答二:作者-到處挖坑蔣玉成
回答三:作者-渠梁
首先,Attention機制是由Encoder-Decoder架構(gòu)而來,且最初是用于完成NLP領域中的翻譯(Translation)任務。那么輸入輸出就是非常明顯的 Source-Target的對應關系,經(jīng)典的Seq2Seq結(jié)構(gòu)是從Encoder生成出一個語義向量(Context vector)而不再變化,然后將這個語義向量送入Decoder配合解碼輸出。這種方法的最大問題就是這個語義向量,我們是希望它一成不變好呢?還是它最好能配合Decoder動態(tài)調(diào)整自己,來使Target中的某些token與Source中的真正“有決定意義”的token關聯(lián)起來好呢?
往期精彩回顧
交流群
歡迎加入機器學習愛好者微信群一起和同行交流,目前有機器學習交流群、博士群、博士申報交流、CV、NLP等微信群,請掃描下面的微信號加群,備注:”昵稱-學校/公司-研究方向“,例如:”張小明-浙大-CV“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~(也可以加入機器學習交流qq群772479961)
評論
圖片
表情
