Transformer Decoder-Only 模型批量生成 Trick

極市導(dǎo)讀
本文給出了一個(gè)用單Transformer decoder( GPT)模型進(jìn)行批量生成時(shí)的解決方法。 >>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿

發(fā)現(xiàn)用單 Transformer decoder (Aka GPT)模型進(jìn)行生成時(shí),因?yàn)槲恢脤?duì)齊等問題,進(jìn)行批量生成時(shí)十分麻煩。
訓(xùn)練時(shí),context 和 target 可以直接拼一起,然后一個(gè) Batch 內(nèi)通過裁剪或 Padding 到相同長(zhǎng)度來進(jìn)行批量訓(xùn)練。但生成時(shí),只有 context,每個(gè)長(zhǎng)度還不同,如果 Padding 到相同長(zhǎng)度,直接進(jìn)行生成的話,會(huì)讓生成階段和訓(xùn)練階段有巨大 gap,導(dǎo)致生成不了好的結(jié)果。
解決問題的最好方法就是——不解決問題。直接一條條輸出吧。

但如果不批量生成,模型小數(shù)據(jù)少時(shí)還好,站起來喝杯水撒泡尿時(shí)間就差不多了。但模型一大且數(shù)據(jù)量一大,花的時(shí)間就太大了。
手動(dòng)開幾個(gè)進(jìn)程同時(shí)跑多個(gè)模型也不是不行,但太美了。
所以只能想辦法解決了。
訓(xùn)練階段解決
通過 Padding 來解決的最主要問題是,生成和訓(xùn)練階段的差別太大,那是不是在訓(xùn)練時(shí)就給 Padding 直接放在 Context 后,再直接拼 target 就行。
可行,但成本太大了,還得重訓(xùn)模型。
所以還是不行。
利用 Transformer 特性
于是就想,如何通過處理讓生成時(shí)模擬訓(xùn)練時(shí)狀況,讓模型以為 target 位置是直接在 context 后,且只參考 context。
需要明確一點(diǎn),Transformer 里因?yàn)槲恢眯畔⒅饕ㄟ^位置編碼來表示的,所以只要對(duì)應(yīng)的位置編碼不變,即使輸入向量順序再怎么變,對(duì) Transformer 來說還是差不多,這也是一些技巧如 PLM(Permutation Language Model)) 得以實(shí)現(xiàn)的原理。
直接這樣說太抽象了,舉個(gè)栗子。
假設(shè)一個(gè) batch 長(zhǎng)度不一樣樣本訓(xùn)練時(shí)如下
input_ids:
1 3 2 6 2 0 0
1 3 6 2 5 4 2
2 為分割和終止符,可看到訓(xùn)練時(shí),通過給第一句 padding,算 loss 時(shí) padding 位置都不算上來進(jìn)行訓(xùn)練。
而 inference 時(shí),只有 context,即使 padding 也會(huì)是下面這樣
1 3 2 0
1 3 6 2
這種情況下如果直接用默認(rèn)的 pos_ids 和 atten_mask (不了解的看The Annotated Transformer),第一句就會(huì)出現(xiàn)問題。
對(duì)比一下,訓(xùn)練時(shí)用到的三個(gè)參數(shù)
1 3 2 6 2 0 0 (input_ids)
0 1 2 3 4 0 0 (pos_ids)
1 1 1 1 1 0 0 (atten_mask)
訓(xùn)練時(shí)當(dāng)生成 6 的時(shí)候看到的是
1 3 2
0 1 2
1 1 1
再來看看生成時(shí)的情況,生成 6 的時(shí)候直接看到的是
1 3 2 0
0 1 2 3
1 1 1 0
首先拿的是最后 padding 位置的向量來預(yù)測(cè)下一個(gè),同時(shí)還有個(gè)問題就是,當(dāng)預(yù)測(cè)完成一個(gè)時(shí),之后拿到的位置 id 是不對(duì)的,這里假設(shè)預(yù)測(cè)成功為 6
1 3 2 0 6
0 1 2 3 4
1 1 1 0 1
會(huì)發(fā)現(xiàn)用 6 來預(yù)測(cè)下一個(gè)詞時(shí)已經(jīng)和訓(xùn)練時(shí)不一樣了,因?yàn)橛?xùn)練時(shí) 6 對(duì)應(yīng)的位置 id 是 3
實(shí)際這樣用時(shí),我也發(fā)現(xiàn)生成結(jié)果總是錯(cuò)開幾個(gè)字,像是給刀直接切開了一樣。
于是改進(jìn),最簡(jiǎn)單方法是直接給 padding 的位置向量都設(shè)成 padding 前的位置,這樣預(yù)測(cè)時(shí)位置向量就對(duì)了。
1 3 2 0 6
0 1 2 2 3
1 1 1 0 1
但這只解決了一個(gè)問題,即生成過程中的問題,第一個(gè)位置拿的還是 padding 位置進(jìn)行的輸出。這里有個(gè)解決方法,就是生成時(shí),第一次預(yù)測(cè)取到 padding 前 token,之后就依次取最后一個(gè)進(jìn)行預(yù)測(cè)了。
這樣基本上就算是解決問題了,但生成時(shí)第一次和之后還得區(qū)分開,說實(shí)話還是有點(diǎn) ugly.
還可以進(jìn)一步優(yōu)化。
Left Padding
解決方法很簡(jiǎn)單,思維掉轉(zhuǎn)下就好了,因?yàn)椴⑿猩蓵r(shí)都是從最后一位開始取,那能不能直接給 padding 放到前面去呢。
于是生成時(shí)一個(gè) batch 會(huì)變成這樣
input_ids:
0 0 1 3 2
1 3 6 2 5
那么對(duì)于第一條進(jìn)行預(yù)測(cè)時(shí),也只需要這樣設(shè)置一下 pos_id 和 atten_mask 就行
0 1 3 2
0 0 1 2
0 1 1 1
這樣子生成 6 時(shí),位置向量就能自然而然銜接上,同時(shí) atten_mask 也給前面的 padding 完美 mask 掉了。
完美解決!速度一下提高了好幾倍。
如果覺得有用,就請(qǐng)分享到朋友圈吧!
公眾號(hào)后臺(tái)回復(fù)“目標(biāo)檢測(cè)競(jìng)賽”獲取目標(biāo)檢測(cè)競(jìng)賽經(jīng)驗(yàn)資源~

# CV技術(shù)社群邀請(qǐng)函 #
備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)
即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群
每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動(dòng)交流~

