圖解 transformer 中的自注意力機制
共 13352字,需瀏覽 27分鐘
·
2024-04-22 08:13
↓推薦關注↓
注意力機制
在整個注意力過程中,模型會學習了三個權(quán)重:查詢、鍵和值。查詢、鍵和值的思想來源于信息檢索系統(tǒng)。所以我們先理解數(shù)據(jù)庫查詢的思想。
假設有一個數(shù)據(jù)庫,里面有所有一些作家和他們的書籍信息。現(xiàn)在我想讀一些Rabindranath寫的書:
在數(shù)據(jù)庫中,作者名字類似于鍵,圖書類似于值。查詢的關鍵詞Rabindranath是這個問題的鍵。所以需要計算查詢和數(shù)據(jù)庫的鍵(數(shù)據(jù)庫中的所有作者)之間的相似度,然后返回最相似作者的值(書籍)。
同樣,注意力有三個矩陣,分別是查詢矩陣(Q)、鍵矩陣(K)和值矩陣(V)。它們中的每一個都具有與輸入嵌入相同的維數(shù)。模型在訓練中學習這些度量的值。
我們可以假設我們從每個單詞中創(chuàng)建一個向量,這樣我們就可以處理信息。對于每個單詞,生成一個512維的向量。所有3個矩陣都是512x512(因為單詞嵌入的維度是512)。對于每個標記嵌入,我們將其與所有三個矩陣(Q, K, V)相乘,每個標記將有3個長度為512的中間向量。
接下來計算分數(shù),它是查詢和鍵向量之間的點積。分數(shù)決定了當我們在某個位置編碼單詞時,對輸入句子的其他部分的關注程度。
然后將點積除以關鍵向量維數(shù)的平方根。這種縮放是為了防止點積變得太大或太小(取決于正值或負值),因為這可能導致訓練期間的數(shù)值不穩(wěn)定。選擇比例因子是為了確保點積的方差近似等于1。
然后通過softmax操作傳遞結(jié)果。這將分數(shù)標準化:它們都是正的,并且加起來等于1。softmax輸出決定了我們應該從不同的單詞中獲取多少信息或特征(值),也就是在計算權(quán)重。
這里需要注意的一點是,為什么需要其他單詞的信息/特征?因為我們的語言是有上下文含義的,一個相同的單詞出現(xiàn)在不同的語境,含義也不一樣。
最后一步就是計算softmax與這些值的乘積,并將它們相加。
可視化圖解
上面邏輯都是文字內(nèi)容,看起來有一些枯燥,下面我們可視化它的矢量化實現(xiàn)。這樣可以更加深入的理解。
查詢鍵和矩陣的計算方法如下
同樣的方法可以計算鍵向量和值向量。
最后計算得分和注意力輸出。
簡單代碼實現(xiàn)
import torch
import torch.nn as nn
from typing import List
def get_input_embeddings(words: List[str], embeddings_dim: int):
# we are creating random vector of embeddings_dim size for each words
# normally we train a tokenizer to get the embeddings.
# check the blog on tokenizer to learn about this part
embeddings = [torch.randn(embeddings_dim) for word in words]
return embeddings
text = "I should sleep now"
words = text.split(" ")
len(words) # 4
embeddings_dim = 512 # 512 dim because the original paper uses it. we can use other dim also
embeddings = get_input_embeddings(words, embeddings_dim=embeddings_dim)
embeddings[0].shape # torch.Size([512])
# initialize the query, key and value metrices
query_matrix = nn.Linear(embeddings_dim, embeddings_dim)
key_matrix = nn.Linear(embeddings_dim, embeddings_dim)
value_matrix = nn.Linear(embeddings_dim, embeddings_dim)
query_matrix.weight.shape, key_matrix.weight.shape, value_matrix.weight.shape # torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])
# query, key and value vectors computation for each words embeddings
query_vectors = torch.stack([query_matrix(embedding) for embedding in embeddings])
key_vectors = torch.stack([key_matrix(embedding) for embedding in embeddings])
value_vectors = torch.stack([value_matrix(embedding) for embedding in embeddings])
query_vectors.shape, key_vectors.shape, value_vectors.shape # torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])
# compute the score
scores = torch.matmul(query_vectors, key_vectors.transpose(-2, -1)) / torch.sqrt(torch.tensor(embeddings_dim, dtype=torch.float32))
scores.shape # torch.Size([4, 4])
# compute the attention weights for each of the words with the other words
softmax = nn.Softmax(dim=-1)
attention_weights = softmax(scores)
attention_weights.shape # torch.Size([4, 4])
# attention output
output = torch.matmul(attention_weights, value_vectors)
output.shape # torch.Size([4, 512])
以上代碼只是為了展示注意力機制的實現(xiàn),并未優(yōu)化。
多頭注意力
上面提到的注意力是單頭注意力,在原論文中有8個頭。對于多頭和單多頭注意力計算相同,只是查詢(q0-q3),鍵(k0-k3),值(v0-v3)中間向量會有一些區(qū)別。
之后將查詢向量分成相等的部分(有多少頭就分成多少)。在上圖中有8個頭,查詢,鍵和值向量的維度為512。所以就變?yōu)榱?個64維的向量。
把前64個向量放到第一個頭,第二組向量放到第二個頭,以此類推。在上面的圖片中,我只展示了第一個頭的計算。
這里需要注意的是:不同的框架有不同的實現(xiàn)方法,pytorch官方的實現(xiàn)是上面這種,但是tf和一些第三方的代碼中是將每個頭分開計算了,比如8個頭會使用8個linear(tf的dense)而不是一個大linear再拆解。還記得Pytorch的transformer里面要求emb_dim能被num_heads整除嗎,就是因為這個
使用哪種方式都可以,因為最終的結(jié)果都類似影響不大。
當我們在一個head中有了小查詢、鍵和值(64 dim的)之后,計算剩下的邏輯與單個head注意相同。最后得到的64維的向量來自每個頭。
我們將每個頭的64個輸出組合起來,得到最后的512個dim輸出向量。
多頭注意力可以表示數(shù)據(jù)中的復雜關系。每個頭都能學習不同的模式。多個頭還提供了同時處理輸入表示的不同子空間(本例:64個向量表示512個原始向量)的能力。
多頭注意代碼實現(xiàn)
num_heads = 8
# batch dim is 1 since we are processing one text.
batch_size = 1
text = "I should sleep now"
words = text.split(" ")
len(words) # 4
embeddings_dim = 512
embeddings = get_input_embeddings(words, embeddings_dim=embeddings_dim)
embeddings[0].shape # torch.Size([512])
# initialize the query, key and value metrices
query_matrix = nn.Linear(embeddings_dim, embeddings_dim)
key_matrix = nn.Linear(embeddings_dim, embeddings_dim)
value_matrix = nn.Linear(embeddings_dim, embeddings_dim)
query_matrix.weight.shape, key_matrix.weight.shape, value_matrix.weight.shape # torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])
# query, key and value vectors computation for each words embeddings
query_vectors = torch.stack([query_matrix(embedding) for embedding in embeddings])
key_vectors = torch.stack([key_matrix(embedding) for embedding in embeddings])
value_vectors = torch.stack([value_matrix(embedding) for embedding in embeddings])
query_vectors.shape, key_vectors.shape, value_vectors.shape # torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])
# (batch_size, num_heads, seq_len, embeddings_dim)
query_vectors_view = query_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
key_vectors_view = key_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
value_vectors_view = value_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
query_vectors_view.shape, key_vectors_view.shape, value_vectors_view.shape
# torch.Size([1, 8, 4, 64]),
# torch.Size([1, 8, 4, 64]),
# torch.Size([1, 8, 4, 64])
# We are splitting the each vectors into 8 heads.
# Assuming we have one text (batch size of 1), So we split
# the embedding vectors also into 8 parts. Each head will
# take these parts. If we do this one head at a time.
head1_query_vector = query_vectors_view[0, 0, ...]
head1_key_vector = key_vectors_view[0, 0, ...]
head1_value_vector = value_vectors_view[0, 0, ...]
head1_query_vector.shape, head1_key_vector.shape, head1_value_vector.shape
# The above vectors are of same size as before only the feature dim is changed from 512 to 64
# compute the score
scores_head1 = torch.matmul(head1_query_vector, head1_key_vector.permute(1, 0)) / torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))
scores_head1.shape # torch.Size([4, 4])
# compute the attention weights for each of the words with the other words
softmax = nn.Softmax(dim=-1)
attention_weights_head1 = softmax(scores_head1)
attention_weights_head1.shape # torch.Size([4, 4])
output_head1 = torch.matmul(attention_weights_head1, head1_value_vector)
output_head1.shape # torch.Size([4, 512])
# we can compute the output for all the heads
outputs = []
for head_idx in range(num_heads):
head_idx_query_vector = query_vectors_view[0, head_idx, ...]
head_idx_key_vector = key_vectors_view[0, head_idx, ...]
head_idx_value_vector = value_vectors_view[0, head_idx, ...]
scores_head_idx = torch.matmul(head_idx_query_vector, head_idx_key_vector.permute(1, 0)) / torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))
softmax = nn.Softmax(dim=-1)
attention_weights_idx = softmax(scores_head_idx)
output = torch.matmul(attention_weights_idx, head_idx_value_vector)
outputs.append(output)
[out.shape for out in outputs]
# [torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64])]
# stack the result from each heads for the corresponding words
word0_outputs = torch.cat([out[0] for out in outputs])
word0_outputs.shape
# lets do it for all the words
attn_outputs = []
for i in range(len(words)):
attn_output = torch.cat([out[i] for out in outputs])
attn_outputs.append(attn_output)
[attn_output.shape for attn_output in attn_outputs] # [torch.Size([512]), torch.Size([512]), torch.Size([512]), torch.Size([512])]
# Now lets do it in vectorize way.
# We can not permute the last two dimension of the key vector.
key_vectors_view.permute(0, 1, 3, 2).shape # torch.Size([1, 8, 64, 4])
# Transpose the key vector on the last dim
score = torch.matmul(query_vectors_view, key_vectors_view.permute(0, 1, 3, 2)) # Q*k
score = torch.softmax(score, dim=-1)
# reshape the results
attention_results = torch.matmul(score, value_vectors_view)
attention_results.shape # [1, 8, 4, 64]
# merge the results
attention_results = attention_results.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, embeddings_dim)
attention_results.shape # torch.Size([1, 4, 512])
總結(jié)
注意力機制(attention mechanism)是Transformer模型中的重要組成部分。Transformer是一種基于自注意力機制(self-attention)的神經(jīng)網(wǎng)絡模型,廣泛應用于自然語言處理任務,如機器翻譯、文本生成和語言模型等。本文介紹的自注意力機制是Transformer模型的基礎,在此基礎之上衍生發(fā)展出了各種不同的更加高效的注意力機制,所以深入了解自注意力機制,將能夠更好地理解Transformer模型的設計原理和工作機制,以及如何在具體的各種任務中應用和調(diào)整模型。這將有助于你更有效地使用Transformer模型并進行相關研究和開發(fā)。
最后有興趣的可以看看這個,它里面包含了pytorch的transformer的完整實現(xiàn):https://www.kaggle.com/code/hengck23/lb-0-67-one-pytorch-transformer-solution
- EOF -
星球服務
知識星球是一個面向 全體學生和在職人員 的技術(shù)交流平臺,旨在為大家提供社招/校招準備攻略、面試題庫、面試經(jīng)驗、學習路線、求職答疑、項目實戰(zhàn)案例、內(nèi)推機會等內(nèi)容,幫你快速成長、告別迷茫。
涉及Python,數(shù)據(jù)分析,數(shù)據(jù)挖掘,機器學習,深度學習,大數(shù)據(jù),搜光推、自然語言處理、計算機視覺、web 開發(fā)、大模型、多模態(tài)、Langchain、擴散模型、知識圖譜等方向。
我們會不定期開展知識星球立減優(yōu)惠活動,加入星球前可以添加城哥微信:dkl88191,咨詢詳情。
技術(shù)學習資料如下,星球成員可免費獲取2個,非星球成員,添加城哥微信:dkl88191,可以單獨購買。
