實操教程|只用兩行代碼,我讓Transformer推理加速了50倍

極市導(dǎo)讀
本文介紹了一個Transformer系列模型推理加速庫--lightseq,僅需兩行代碼即可讓模型預(yù)測速度加速50倍,附有相關(guān)代碼。 >>加入極市CV技術(shù)交流群,走在計算機視覺的最前沿
最近有學(xué)妹問我,我訓(xùn)了一個Transformer模型,但是預(yù)測好慢啊,有啥解決方案嗎?
我心想,你又想好,又想快,咋不上天呢?

于是我跟她說,你可以試試lightseq啊,跟閃電??一樣快,用了你就可以上天了。
她一臉懵比,lightseq是啥玩意兒???咋就能讓我的模型起飛了呢?
我跟她說,你不需要知道太多細(xì)節(jié),你只需要知道它是一個Transformer系列模型推理加速庫就行了。
她還是一臉疑惑,那用起來能有huggingface方便嗎?你看人家就兩行代碼。
我不屑一笑,就這?lightseq也只要兩行代碼就夠了!

為了方便,我用了一個bart模型預(yù)測句子中mask單詞的例子來給她吹了一波。
不懂什么是bart?建議先去看看huggingface的文檔:
https://huggingface.co/transformers/model_doc/bart.html
huggingface bart
我們平時想用huggingface的bart來預(yù)測句子中的mask單詞,大體上都會像下面這樣寫代碼:
from transformers import BartTokenizer, BartForConditionalGenerationtokenizer = BartTokenizer.from_pretrained("facebook/bart-base")model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")sentences = ["I love that girl, but <mask> does not <mask> me."]inputs = tokenizer(sentences, return_tensors="pt", padding=True)generated_ids = model.generate(inputs["input_ids"], max_length=50)res = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)print(res)
當(dāng)然運行前要先安裝一下transformers包:
pip3 install transformers
最后會輸出句子“I love that girl, but she does not love me.”,句子中的兩個“mask”被預(yù)測成了“she”和“l(fā)ove”。
看起來預(yù)測的很nice,但是預(yù)測的也太慢了,這要是有一堆句子要去預(yù)測,不得等到猴年馬月?

接下來我們來看看lightseq是怎么加速預(yù)測的。
lightseq bart
代碼我都放在下面地址了,只要兩分鐘就能跑出結(jié)果了:
https://github.com/godweiyang/lightseq/tree/python_example/example/python
運行前要先安裝一下lightseq包:
pip3 install lightseq
首先lightseq只能接收Protocol Buffer協(xié)議定義的模型文件,如果你不知道這是啥也沒關(guān)系,因為我們幫你寫好了模型轉(zhuǎn)換的腳本,就是hf_bart_export.py,它會將huggingface預(yù)訓(xùn)練的bart模型轉(zhuǎn)換為transformer_pb2.py定義好的Protocol Buffer格式。
所以直接運行python3 hf_bart_export.py就行了,這里我們用的是bart-base模型。
運行完了會發(fā)現(xiàn)執(zhí)行目錄下多出一個lightseq_bart_base.pb文件,這就是轉(zhuǎn)換后的模型文件。
最后直接跟huggingface一樣,兩行代碼就能搞定啦:
import lightseqfrom transformers import BartTokenizertokenizer = BartTokenizer.from_pretrained("facebook/bart-base")model = lightseq.Transformer("lightseq_bart_base.pb", 128)sentences = ["I love that girl, but <mask> does not <mask> me."]inputs = tokenizer(sentences, return_tensors="pt", padding=True)generated_ids = model.infer(inputs["input_ids"])generated_ids = [ids[0] for ids in generated_ids[0]]res = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)print(res)
看得出來僅僅替換了模型定義和模型推理那兩行代碼而已,是不是非常簡單快速?

這時候她又問了,那我換一個模型,比如bert,要怎么導(dǎo)出pb模型呢?
也很簡單,只需要為bert也單獨寫一個hf_bert_export.py就行了。不過目前還在開發(fā)中,之后會慢慢完善常見的一些模型的。
速度到底怎么樣?
我寫好了一個例子,就在ls_bart.py里,直接運行就行了,當(dāng)然你也可以加上--user_input參數(shù)來手動輸入句子。
輸入的句子是:
I love that girl, but <mask> does not <mask> me.She is so <mask> that I can not help glance at <mask>.Nothing's gonna <mask> my love for you.Drop everything now. Meet me in the pouring <mask>. Kiss me on the sidewalk.
運行結(jié)果如下:
=========================lightseq=========================lightseq generating...lightseq time: 0.034502994269132614slightseq results:I love that girl, but she does not love me.She is so beautiful that I can not help glance at her.Nothing's gonna change my love for you.Drop everything now. Meet me in the pouring rain. Kiss me on the sidewalk.=========================huggingface=========================huggingface generating...huggingface time: 1.6297104470431805shuggingface results:I love that girl, but she does not love me.She is so beautiful that I can not help glance at her.Nothing's gonna change my love for you.Drop everything now. Meet me in the pouring rain. Kiss me on the sidewalk.
可以看出預(yù)測的是真的牛批,最后兩句歌詞都預(yù)測的很完美,能看出是啥歌嗎?
再看預(yù)測時間,lightseq是huggingface的47倍左右,真是一個天上一個地下啊。

總結(jié)
總結(jié)一下,想要使用lightseq加速你的模型,只需要兩步就行了:
將你的模型轉(zhuǎn)換為pb格式的模型。(lightseq為你寫好了轉(zhuǎn)換腳本,不斷更新中) 調(diào)用 lightseq.Transformer和model.infer進(jìn)行快速推理。
學(xué)妹趕緊打住了我,好了好了,我知道很 了。還給你裝起來了,我這就去用。
但是源碼哪里有?我想學(xué)一學(xué)。
我又甩給她一串地址:

好好看,好好學(xué),都是CUDA寫的,要是看得迷糊,建議先去看看我之前的入門教程嗷:
godweiyang:熬了幾個通宵,我寫了份CUDA新手入門代碼
https://zhuanlan.zhihu.com/p/360441891
從此,世上又多了一位快如??的女孩。
推薦閱讀
2021-04-10
2021-04-02
2021-03-24

# CV技術(shù)社群邀請函 #
備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測-深圳)
即可申請加入極市目標(biāo)檢測/圖像分割/工業(yè)檢測/人臉/醫(yī)學(xué)影像/3D/SLAM/自動駕駛/超分辨率/姿態(tài)估計/ReID/GAN/圖像增強/OCR/視頻理解等技術(shù)交流群
每月大咖直播分享、真實項目需求對接、求職內(nèi)推、算法競賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動交流~

