<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          保姆級(jí)硬核教程:圖解Transformer

          共 13831字,需瀏覽 28分鐘

           ·

          2021-01-21 20:10

          點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師

          設(shè)為星標(biāo),干貨直達(dá)!


          一、前言

          大家好,我是 Jack。

          本文是圖解 AI 算法系列教程的第二篇,今天的主角是 Transformer

          Transformer 可以做很多有趣而又有意義的事情。

          比如我寫(xiě)過(guò)的《用自己訓(xùn)練的AI玩王者榮耀是什么體驗(yàn)?》。

          再比如 OpenAIDALL·E,可以魔法一般地按照自然語(yǔ)言文字描述直接生成對(duì)應(yīng)圖片!

          輸入文本:鱷梨形狀的扶手椅。

          AI 生成的圖像:

          兩者都是多模態(tài)的應(yīng)用,這也是各大巨頭的跟進(jìn)方向,可謂大勢(shì)所趨

          Transformer 最初主要應(yīng)用于一些自然語(yǔ)言處理場(chǎng)景,比如翻譯、文本分類(lèi)、寫(xiě)小說(shuō)、寫(xiě)歌等。

          隨著技術(shù)的發(fā)展,Transformer 開(kāi)始征戰(zhàn)視覺(jué)領(lǐng)域,分類(lèi)、檢測(cè)等任務(wù)均不在話下,逐漸走上了多模態(tài)的道路。

          Transformer 近兩年非常火爆,內(nèi)容也很多,要想講清楚,還涉及一些基于該結(jié)構(gòu)的預(yù)訓(xùn)練模型,例如著名的 BERTGPT,以及剛出的 DALL·E 等。

          它們都是基于 Transformer 的上層應(yīng)用,因?yàn)?Transformer 很難訓(xùn)練,巨頭們就肩負(fù)起了造福大眾的使命,開(kāi)源了各種好用的預(yù)訓(xùn)練模型

          我們都是站在巨人肩膀上學(xué)習(xí),用開(kāi)源的預(yù)訓(xùn)練模型在一些特定的應(yīng)用場(chǎng)景進(jìn)行遷移學(xué)習(xí)

          篇幅有限,本文先講解 Transformer 的基礎(chǔ)原理,希望每個(gè)人都可以看懂。

          后面我會(huì)繼續(xù)寫(xiě) BERTGPT 等內(nèi)容,更新可能慢一些,但是跟著學(xué),絕對(duì)都能有所收獲。

          還是那句話:如果你喜歡這個(gè) AI 算法系列教程,一定要讓我知道,轉(zhuǎn)發(fā)在看支持,更文更有動(dòng)力!

          二、Transformer

          TransformerGoogle2017 年提出的用于機(jī)器翻譯的模型。

          Transformer 的內(nèi)部,在本質(zhì)上是一個(gè) Encoder-Decoder 的結(jié)構(gòu),即 編碼器-解碼器

          Transformer 中拋棄了傳統(tǒng)的 CNNRNN,整個(gè)網(wǎng)絡(luò)結(jié)構(gòu)完全由 Attention 機(jī)制組成,并且采用了 6Encoder-Decoder 結(jié)構(gòu)。

          顯然,Transformer 主要分為兩大部分,分別是編碼器解碼器

          整個(gè) Transformer 是由 6 個(gè)這樣的結(jié)構(gòu)組成,為了方便理解,我們只看其中一個(gè)Encoder-Decoder 結(jié)構(gòu)。

          以一個(gè)簡(jiǎn)單的例子進(jìn)行說(shuō)明:

          Why do we work?,我們?yōu)槭裁垂ぷ鳎?/p>

          左側(cè)紅框是編碼器,右側(cè)紅框是解碼器

          編碼器負(fù)責(zé)把自然語(yǔ)言序列映射成為隱藏層(上圖第2步),即含有自然語(yǔ)言序列的數(shù)學(xué)表達(dá)。

          解碼器把隱藏層再映射為自然語(yǔ)言序列,從而使我們可以解決各種問(wèn)題,如情感分析、機(jī)器翻譯、摘要生成、語(yǔ)義關(guān)系抽取等。

          簡(jiǎn)單說(shuō)下,上圖每一步都做了什么:

          • 輸入自然語(yǔ)言序列到編碼器: Why do we work?(為什么要工作);
          • 編碼器輸出的隱藏層,再輸入到解碼器;
          • 輸入 <??????????> (起始)符號(hào)到解碼器;
          • 解碼器得到第一個(gè)字"為";
          • 將得到的第一個(gè)字"為"落下來(lái)再輸入到解碼器;
          • 解碼器得到第二個(gè)字"什";
          • 將得到的第二字再落下來(lái),直到解碼器輸出 <??????> (終止符),即序列生成完成。

          解碼器和編碼器的結(jié)構(gòu)類(lèi)似,本文以編碼器部分進(jìn)行講解。即把自然語(yǔ)言序列映射為隱藏層的數(shù)學(xué)表達(dá)的過(guò)程,因?yàn)槔斫饬司幋a器中的結(jié)構(gòu),理解解碼器就非常簡(jiǎn)單了。

          為了方便學(xué)習(xí),我將編碼器分為 4 個(gè)部分,依次講解。

          1、位置嵌入(???????????????????? ????????????????)

          我們輸入數(shù)據(jù) X 維度為[batch size, sequence length]的數(shù)據(jù),比如我們?yōu)槭裁垂ぷ?/code>。

          batch size 就是 batch 的大小,這里只有一句話,所以 batch size1sequence length 是句子的長(zhǎng)度,一共 7 個(gè)字,所以輸入的數(shù)據(jù)維度是 [1, 7]

          我們不能直接將這句話輸入到編碼器中,因?yàn)?Tranformer 不認(rèn)識(shí),我們需要先進(jìn)行字嵌入,即得到圖中的

          簡(jiǎn)單點(diǎn)說(shuō),就是文字->字向量的轉(zhuǎn)換,這種轉(zhuǎn)換是將文字轉(zhuǎn)換為計(jì)算機(jī)認(rèn)識(shí)的數(shù)學(xué)表示,用到的方法就是 Word2VecWord2Vec 的具體細(xì)節(jié),對(duì)于初學(xué)者暫且不用了解,這個(gè)是可以直接使用的。

          得到的 的維度是 [batch size, sequence length, embedding dimension]embedding dimension 的大小由 Word2Vec 算法決定,Tranformer 采用 512 長(zhǎng)度的字向量。所以 的維度是 [1, 7, 512]

          至此,輸入的我們?yōu)槭裁垂ぷ?/code>,可以用一個(gè)矩陣來(lái)簡(jiǎn)化表示。

          我們知道,文字的先后順序,很重要。

          比如吃飯沒(méi)沒(méi)吃飯沒(méi)飯吃飯吃沒(méi)飯沒(méi)吃,同樣三個(gè)字,順序顛倒,所表達(dá)的含義就不同了。

          文字的位置信息很重要,Tranformer 沒(méi)有類(lèi)似 RNN 的循環(huán)結(jié)構(gòu),沒(méi)有捕捉順序序列的能力。

          為了保留這種位置信息交給 Tranformer 學(xué)習(xí),我們需要用到位置嵌入

          加入位置信息的方式非常多,最簡(jiǎn)單的可以是直接將絕對(duì)坐標(biāo) 0,1,2 編碼。

          Tranformer 采用的是 sin-cos 規(guī)則,使用了 sincos 函數(shù)的線性變換來(lái)提供給模型位置信息:

          上式中 pos 指的是句中字的位置,取值范圍是 [0, ?????? ???????????????? ???????????)i 指的是字嵌入的維度, 取值范圍是 [0, ?????????????????? ??????????????????) 就是 ?????????????????? ?????????????????? 的大小。

          上面有 sincos 一組公式,也就是對(duì)應(yīng)著 ?????????????????? ?????????????????? 維度的一組奇數(shù)和偶數(shù)的序號(hào)的維度,從而產(chǎn)生不同的周期性變化。

          可以用代碼,簡(jiǎn)單看下效果。

          # 導(dǎo)入依賴(lài)庫(kù)
          import numpy as np
          import matplotlib.pyplot as plt
          import seaborn as sns
          import math

          def get_positional_encoding(max_seq_len, embed_dim):
              # 初始化一個(gè)positional encoding
              # embed_dim: 字嵌入的維度
              # max_seq_len: 最大的序列長(zhǎng)度
              positional_encoding = np.array([
                  [pos / np.power(100002 * i / embed_dim) for i in range(embed_dim)]
                  if pos != 0 else np.zeros(embed_dim) for pos in range(max_seq_len)])
              positional_encoding[1:, 0::2] = np.sin(positional_encoding[1:, 0::2])  # dim 2i 偶數(shù)
              positional_encoding[1:, 1::2] = np.cos(positional_encoding[1:, 1::2])  # dim 2i+1 奇數(shù)
              # 歸一化, 用位置嵌入的每一行除以它的模長(zhǎng)
              # denominator = np.sqrt(np.sum(position_enc**2, axis=1, keepdims=True))
              # position_enc = position_enc / (denominator + 1e-8)
              return positional_encoding
              
          positional_encoding = get_positional_encoding(max_seq_len=100, embed_dim=16)
          plt.figure(figsize=(10,10))
          sns.heatmap(positional_encoding)
          plt.title("Sinusoidal Function")
          plt.xlabel("hidden dimension")
          plt.ylabel("sequence length")

          可以看到,位置嵌入在 ?????????????????? ?????????????????? (也是hidden dimension )維度上隨著維度序號(hào)增大,周期變化會(huì)越來(lái)越慢,而產(chǎn)生一種包含位置信息的紋理。

          就這樣,產(chǎn)生獨(dú)一的紋理位置信息,模型從而學(xué)到位置之間的依賴(lài)關(guān)系和自然語(yǔ)言的時(shí)序特性。

          最后,將 位置嵌入 相加,送給下一層。

          2、自注意力層(???????? ?????????????????? ?????????????????)

          直接看下圖筆記,講解的非常詳細(xì)。

          多頭的意義在于, 得到的矩陣就叫注意力矩陣,它可以表示每個(gè)字與其他字的相似程度。因?yàn)椋蛄康狞c(diǎn)積值越大,說(shuō)明兩個(gè)向量越接近。

          我們的目的是,讓每個(gè)字都含有當(dāng)前這個(gè)句子中的所有字的信息,用注意力層,我們做到了。

          需要注意的是,在上面 ???????? ?????????????????? 的計(jì)算過(guò)程中,我們通常使用 ???????? ?????????,也就是一次計(jì)算多句話,上文舉例只用了一個(gè)句子。

          每個(gè)句子的長(zhǎng)度是不一樣的,需要按照最長(zhǎng)的句子的長(zhǎng)度統(tǒng)一處理。對(duì)于短的句子,進(jìn)行 Padding 操作,一般我們用 0 來(lái)進(jìn)行填充。

          3、殘差鏈接和層歸一化

          加入了殘差設(shè)計(jì)和層歸一化操作,目的是為了防止梯度消失,加快收斂。

          1) 殘差設(shè)計(jì)

          我們?cè)谏弦徊降玫搅私?jīng)過(guò)注意力矩陣加權(quán)之后的 ??, 也就是 ??????????????????(??, ??, ??),我們對(duì)它進(jìn)行一下轉(zhuǎn)置,使其和 ???????????????????? 的維度一致, 也就是 [????????? ????????, ???????????????? ???????????, ?????????????????? ??????????????????] ,然后把他們加起來(lái)做殘差連接,直接進(jìn)行元素相加,因?yàn)樗麄兊木S度一致:

          在之后的運(yùn)算里,每經(jīng)過(guò)一個(gè)模塊的運(yùn)算,都要把運(yùn)算之前的值和運(yùn)算之后的值相加,從而得到殘差連接,訓(xùn)練的時(shí)候可以使梯度直接走捷徑反傳到最初始層:

          2) 層歸一化

          作用是把神經(jīng)網(wǎng)絡(luò)中隱藏層歸一為標(biāo)準(zhǔn)正態(tài)分布,也就是 ??.??.?? 獨(dú)立同分布, 以起到加快訓(xùn)練速度, 加速收斂的作用。

          上式中以矩陣的行 (??????) 為單位求均值:

          上式中以矩陣的行 (??????) 為單位求方差:

          然后用每一行每一個(gè)元素減去這行的均值,再除以這行的標(biāo)準(zhǔn)差,從而得到歸一化后的數(shù)值,是為了防止除

          之后引入兩個(gè)可訓(xùn)練參數(shù)來(lái)彌補(bǔ)歸一化的過(guò)程中損失掉的信息,注意表示元素相乘而不是點(diǎn)積,我們一般初始化為全,而為全

          代碼層面非常簡(jiǎn)單,單頭 attention 操作如下:

          class ScaledDotProductAttention(nn.Module):
              ''' Scaled Dot-Product Attention '''

              def __init__(self, temperature, attn_dropout=0.1):
                  super().__init__()
                  self.temperature = temperature
                  self.dropout = nn.Dropout(attn_dropout)

              def forward(self, q, k, v, mask=None):
                  # self.temperature是論文中的d_k ** 0.5,防止梯度過(guò)大
                  # QxK/sqrt(dk)
                  attn = torch.matmul(q / self.temperature, k.transpose(23))

                  if mask is not None:
                      # 屏蔽不想要的輸出
                      attn = attn.masked_fill(mask == 0-1e9)
                  # softmax+dropout
                  attn = self.dropout(F.softmax(attn, dim=-1))
                  # 概率分布xV
                  output = torch.matmul(attn, v)

                  return output, attn

          Multi-Head Attention 實(shí)現(xiàn)在 ScaledDotProductAttention 基礎(chǔ)上構(gòu)建:

          class MultiHeadAttention(nn.Module):
              ''' Multi-Head Attention module '''

              # n_head頭的個(gè)數(shù),默認(rèn)是8
              # d_model編碼向量長(zhǎng)度,例如本文說(shuō)的512
              # d_k, d_v的值一般會(huì)設(shè)置為 n_head * d_k=d_model,
              # 此時(shí)concat后正好和原始輸入一樣,當(dāng)然不相同也可以,因?yàn)楹竺嬗衒c層
              # 相當(dāng)于將可學(xué)習(xí)矩陣分成獨(dú)立的n_head份
              def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
                  super().__init__()
                  # 假設(shè)n_head=8,d_k=64
                  self.n_head = n_head
                  self.d_k = d_k
                  self.d_v = d_v
                  # d_model輸入向量,n_head * d_k輸出向量
                  # 可學(xué)習(xí)W^Q,W^K,W^V矩陣參數(shù)初始化
                  self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
                  self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
                  self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
                  # 最后的輸出維度變換操作
                  self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
                  # 單頭自注意力
                  self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
                  self.dropout = nn.Dropout(dropout)
                  # 層歸一化
                  self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

              def forward(self, q, k, v, mask=None):
                  # 假設(shè)qkv輸入是(b,100,512),100是訓(xùn)練每個(gè)樣本最大單詞個(gè)數(shù)
                  # 一般qkv相等,即自注意力
                  residual = q
                  # 將輸入x和可學(xué)習(xí)矩陣相乘,得到(b,100,512)輸出
                  # 其中512的含義其實(shí)是8x64,8個(gè)head,每個(gè)head的可學(xué)習(xí)矩陣為64維度
                  # q的輸出是(b,100,8,64),kv也是一樣
                  q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
                  k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
                  v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

                  # 變成(b,8,100,64),方便后面計(jì)算,也就是8個(gè)頭單獨(dú)計(jì)算
                  q, k, v = q.transpose(12), k.transpose(12), v.transpose(12)

                  if mask is not None:
                      mask = mask.unsqueeze(1)   # For head axis broadcasting.
                  # 輸出q是(b,8,100,64),維持不變,內(nèi)部計(jì)算流程是:
                  # q*k轉(zhuǎn)置,除以d_k ** 0.5,輸出維度是b,8,100,100即單詞和單詞直接的相似性
                  # 對(duì)最后一個(gè)維度進(jìn)行softmax操作得到b,8,100,100
                  # 最后乘上V,得到b,8,100,64輸出
                  q, attn = self.attention(q, k, v, mask=mask)

                  # b,100,8,64-->b,100,512
                  q = q.transpose(12).contiguous().view(sz_b, len_q, -1)
                  q = self.dropout(self.fc(q))
                  # 殘差計(jì)算
                  q += residual
                  # 層歸一化,在512維度計(jì)算均值和方差,進(jìn)行層歸一化
                  q = self.layer_norm(q)

                  return q, attn

          4、前饋網(wǎng)絡(luò)

          這個(gè)層就沒(méi)啥說(shuō)的了,非常簡(jiǎn)單,直接看代碼吧:

          class PositionwiseFeedForward(nn.Module):
              ''' A two-feed-forward-layer module '''

              def __init__(self, d_in, d_hid, dropout=0.1):
                  super().__init__()
                  # 兩個(gè)fc層,對(duì)最后的512維度進(jìn)行變換
                  self.w_1 = nn.Linear(d_in, d_hid) # position-wise
                  self.w_2 = nn.Linear(d_hid, d_in) # position-wise
                  self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
                  self.dropout = nn.Dropout(dropout)

              def forward(self, x):
                  residual = x

                  x = self.w_2(F.relu(self.w_1(x)))
                  x = self.dropout(x)
                  x += residual

                  x = self.layer_norm(x)

                  return x

          最后,回顧下 ?????????????????????? ?????????????? 的整體結(jié)構(gòu)。

          經(jīng)過(guò)上文的梳理,我們已經(jīng)基本了解了 ?????????????????????? 編碼器的主要構(gòu)成部分,我們下面用公式把一個(gè) ?????????????????????? ?????????? 的計(jì)算過(guò)程整理一下:

          1) 字向量與位置編碼
          2) 自注意力機(jī)制
          3) 殘差連接與層歸一化
          4) 前向網(wǎng)絡(luò)

          其實(shí)就是兩層線性映射并用激活函數(shù)激活,比如說(shuō):

          5) 重復(fù)3)

          三、絮叨

          至此,我們已經(jīng)講完了 Transformer 編碼器的全部?jī)?nèi)容,知道了如何獲得自然語(yǔ)言的位置信息,注意力機(jī)制的工作原理等。

          本文以原理講解為主,后續(xù)我會(huì)繼續(xù)更新實(shí)戰(zhàn)內(nèi)容,教大家如何訓(xùn)練我們自己的有趣又好玩的模型。

          本文硬核,肝了很久,如果喜歡,還望轉(zhuǎn)發(fā)、再看多多支持。

          我是 Jack ,我們下期見(jiàn)。

          ·················END·················



          推薦閱讀

          PyTorch 源碼解讀之 torch.autograd

          Transformer為何能闖入CV界秒殺CNN?

          SWA:讓你的目標(biāo)檢測(cè)模型無(wú)痛漲點(diǎn)1% AP

          CondInst:性能和速度均超越Mask RCNN的實(shí)例分割模型

          centerX: 用新的視角的方式打開(kāi)CenterNet

          mmdetection最小復(fù)刻版(十一):概率Anchor分配機(jī)制PAA深入分析

          MMDetection新版本V2.7發(fā)布,支持DETR,還有YOLOV4在路上!

          CNN:我不是你想的那樣

          TF Object Detection 終于支持TF2了!

          無(wú)需tricks,知識(shí)蒸餾提升ResNet50在ImageNet上準(zhǔn)確度至80%+

          不妨試試MoCo,來(lái)替換ImageNet上pretrain模型!

          重磅!一文深入深度學(xué)習(xí)模型壓縮和加速

          從源碼學(xué)習(xí)Transformer!

          mmdetection最小復(fù)刻版(七):anchor-base和anchor-free差異分析

          mmdetection最小復(fù)刻版(四):獨(dú)家yolo轉(zhuǎn)化內(nèi)幕


          機(jī)器學(xué)習(xí)算法工程師


                                              一個(gè)用心的公眾號(hào)


           


          瀏覽 69
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                    <th id="afajh"><progress id="afajh"></progress></th>
                    91在线无码精品在线看 | 成人久久视频 | 人妖乱伦视频 | 国产播放在线 | 97一区二区 |