<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          Transformer Decoder-Only 模型批量生成 Trick

          共 2741字,需瀏覽 6分鐘

           ·

          2021-06-18 11:19

          ↑ 點(diǎn)擊藍(lán)字 關(guān)注極市平臺(tái)

          作者丨Andy
          來源丨安迪的寫作間
          編輯丨極市平臺(tái)

          極市導(dǎo)讀

           

          本文給出了一個(gè)用單Transformer decoder( GPT)模型進(jìn)行批量生成時(shí)的解決方法。 >>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿

          發(fā)現(xiàn)用單 Transformer decoder (Aka GPT)模型進(jìn)行生成時(shí),因?yàn)槲恢脤?duì)齊等問題,進(jìn)行批量生成時(shí)十分麻煩。

          訓(xùn)練時(shí),context 和 target 可以直接拼一起,然后一個(gè) Batch 內(nèi)通過裁剪或 Padding 到相同長(zhǎng)度來進(jìn)行批量訓(xùn)練。但生成時(shí),只有 context,每個(gè)長(zhǎng)度還不同,如果 Padding 到相同長(zhǎng)度,直接進(jìn)行生成的話,會(huì)讓生成階段和訓(xùn)練階段有巨大 gap,導(dǎo)致生成不了好的結(jié)果。

          解決問題的最好方法就是——不解決問題。直接一條條輸出吧。

          但如果不批量生成,模型小數(shù)據(jù)少時(shí)還好,站起來喝杯水撒泡尿時(shí)間就差不多了。但模型一大且數(shù)據(jù)量一大,花的時(shí)間就太大了。

          手動(dòng)開幾個(gè)進(jìn)程同時(shí)跑多個(gè)模型也不是不行,但太美了。

          所以只能想辦法解決了。

          訓(xùn)練階段解決

          通過 Padding 來解決的最主要問題是,生成和訓(xùn)練階段的差別太大,那是不是在訓(xùn)練時(shí)就給 Padding 直接放在 Context 后,再直接拼 target 就行。

          可行,但成本太大了,還得重訓(xùn)模型。

          所以還是不行。

          利用 Transformer 特性

          于是就想,如何通過處理讓生成時(shí)模擬訓(xùn)練時(shí)狀況,讓模型以為 target 位置是直接在 context 后,且只參考 context。

          需要明確一點(diǎn),Transformer 里因?yàn)槲恢眯畔⒅饕ㄟ^位置編碼來表示的,所以只要對(duì)應(yīng)的位置編碼不變,即使輸入向量順序再怎么變,對(duì) Transformer 來說還是差不多,這也是一些技巧如 PLM(Permutation Language Model)) 得以實(shí)現(xiàn)的原理。

          直接這樣說太抽象了,舉個(gè)栗子。

          假設(shè)一個(gè) batch 長(zhǎng)度不一樣樣本訓(xùn)練時(shí)如下

          input_ids:
          1 3 2 6 2 0 0
          1 3 6 2 5 4 2

          2 為分割和終止符,可看到訓(xùn)練時(shí),通過給第一句 padding,算 loss 時(shí) padding 位置都不算上來進(jìn)行訓(xùn)練。

          而 inference 時(shí),只有 context,即使 padding 也會(huì)是下面這樣

          1 3 2 0
          1 3 6 2

          這種情況下如果直接用默認(rèn)的 pos_ids 和 atten_mask (不了解的看The Annotated Transformer),第一句就會(huì)出現(xiàn)問題。

          對(duì)比一下,訓(xùn)練時(shí)用到的三個(gè)參數(shù)

          1 3 2 6 2 0 0 (input_ids)
          0 1 2 3 4 0 0 (pos_ids)
          1 1 1 1 1 0 0 (atten_mask)

          訓(xùn)練時(shí)當(dāng)生成 6 的時(shí)候看到的是

          1 3 2
          0 1 2
          1 1 1

          再來看看生成時(shí)的情況,生成 6 的時(shí)候直接看到的是

          1 3 2 0
          0 1 2 3
          1 1 1 0

          首先拿的是最后 padding 位置的向量來預(yù)測(cè)下一個(gè),同時(shí)還有個(gè)問題就是,當(dāng)預(yù)測(cè)完成一個(gè)時(shí),之后拿到的位置 id 是不對(duì)的,這里假設(shè)預(yù)測(cè)成功為 6

          1 3 2 0 6
          0 1 2 3 4
          1 1 1 0 1

          會(huì)發(fā)現(xiàn)用 6 來預(yù)測(cè)下一個(gè)詞時(shí)已經(jīng)和訓(xùn)練時(shí)不一樣了,因?yàn)橛?xùn)練時(shí) 6 對(duì)應(yīng)的位置 id 是 3

          實(shí)際這樣用時(shí),我也發(fā)現(xiàn)生成結(jié)果總是錯(cuò)開幾個(gè)字,像是給刀直接切開了一樣。

          于是改進(jìn),最簡(jiǎn)單方法是直接給 padding 的位置向量都設(shè)成 padding 前的位置,這樣預(yù)測(cè)時(shí)位置向量就對(duì)了。

          1 3 2 0 6
          0 1 2 2 3
          1 1 1 0 1

          但這只解決了一個(gè)問題,即生成過程中的問題,第一個(gè)位置拿的還是 padding 位置進(jìn)行的輸出。這里有個(gè)解決方法,就是生成時(shí),第一次預(yù)測(cè)取到 padding 前 token,之后就依次取最后一個(gè)進(jìn)行預(yù)測(cè)了

          這樣基本上就算是解決問題了,但生成時(shí)第一次和之后還得區(qū)分開,說實(shí)話還是有點(diǎn) ugly.

          還可以進(jìn)一步優(yōu)化。

          Left Padding

          解決方法很簡(jiǎn)單,思維掉轉(zhuǎn)下就好了,因?yàn)椴⑿猩蓵r(shí)都是從最后一位開始取,那能不能直接給 padding 放到前面去呢。

          于是生成時(shí)一個(gè) batch 會(huì)變成這樣

          input_ids:
          0 0 1 3 2
          1 3 6 2 5

          那么對(duì)于第一條進(jìn)行預(yù)測(cè)時(shí),也只需要這樣設(shè)置一下 pos_id 和 atten_mask 就行

          0 1 3 2
          0 0 1 2
          0 1 1 1

          這樣子生成 6 時(shí),位置向量就能自然而然銜接上,同時(shí) atten_mask 也給前面的 padding 完美 mask 掉了。

          完美解決!速度一下提高了好幾倍。

          如果覺得有用,就請(qǐng)分享到朋友圈吧!

          △點(diǎn)擊卡片關(guān)注極市平臺(tái),獲取最新CV干貨

          公眾號(hào)后臺(tái)回復(fù)“目標(biāo)檢測(cè)競(jìng)賽”獲取目標(biāo)檢測(cè)競(jìng)賽經(jīng)驗(yàn)資源~


          極市干貨
          YOLO教程:一文讀懂YOLO V5 與 YOLO V4大盤點(diǎn)|YOLO 系目標(biāo)檢測(cè)算法總覽全面解析YOLO V4網(wǎng)絡(luò)結(jié)構(gòu)
          實(shí)操教程:PyTorch vs LibTorch:網(wǎng)絡(luò)推理速度誰(shuí)更快?只用兩行代碼,我讓Transformer推理加速了50倍PyTorch AutoGrad C++層實(shí)現(xiàn)
          算法技巧(trick):深度學(xué)習(xí)訓(xùn)練tricks總結(jié)(有實(shí)驗(yàn)支撐)深度強(qiáng)化學(xué)習(xí)調(diào)參Tricks合集長(zhǎng)尾識(shí)別中的Tricks匯總(AAAI2021
          最新CV競(jìng)賽:2021 高通人工智能應(yīng)用創(chuàng)新大賽CVPR 2021 | Short-video Face Parsing Challenge3D人體目標(biāo)檢測(cè)與行為分析競(jìng)賽開賽,獎(jiǎng)池7萬+,數(shù)據(jù)集達(dá)16671張!


          CV技術(shù)社群邀請(qǐng)函 #

          △長(zhǎng)按添加極市小助手
          添加極市小助手微信(ID : cvmart2)

          備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)


          即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群


          每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動(dòng)交流~



          覺得有用麻煩給個(gè)在看啦~  
          瀏覽 70
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  黄色在线视频网站 | 五十路義母 | 欧美日本视频在线 | 亚洲字幕成人中文在线观看 | 日本黄色一区二区 |