<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          BERT推理加速代碼

          共 2260字,需瀏覽 5分鐘

           ·

          2021-08-13 00:25

          LightSeq的BERT推理加速代碼,大家有需要的可以使用起來(lái)了。

          實(shí)現(xiàn)原理

          這里我直接使用預(yù)訓(xùn)練好的BERT模型,用戶(hù)只需要輸入一個(gè)帶有[MASK]標(biāo)記的句子,就可以自動(dòng)預(yù)測(cè)出完整的句子。

          例如我輸入“巴黎是[MASK]國(guó)的首都”,那么模型就會(huì)輸出“巴黎是法國(guó)的首都。”。

          LightSeq已經(jīng)「完美支持了BERT模型的快速推理」,代碼近期已經(jīng)開(kāi)源:


          GitHub - bytedance/lightseq: LightSeq: A High Performance Library for Sequence Processing and Generationgithub.com/bytedance/lightseqgithub.com/bytedance/lightseq


          BERT推理使用樣例可以參考examples/inference/python目錄下的ls_bert.py文件。我們用LightSeq來(lái)加速BERT推理試試。

          首先需要安裝LightSeq和Hugging Face:

          pip install lightseq transformers

          然后需要將Hugging Face的BERT模型導(dǎo)出為L(zhǎng)ightSeq支持的HDF5模型格式,運(yùn)行examples/inference/python目錄下的hf_bert_export.py文件即可,運(yùn)行前將代碼的第167-168兩行修改為下面這樣,指定使用中文版本的BERT預(yù)訓(xùn)練模型。

          output_lightseq_model_name = "lightseq-bert-base-chinese"
          input_huggingface_bert_model = "bert-base-chinese"

          然后就會(huì)在運(yùn)行目錄下生成一個(gè)lightseq-bert-base-chinese.hdf5模型文件,導(dǎo)出就成功啦。

          最后使用LightSeq進(jìn)行推理即可:

          import torch
          from transformers import AutoTokenizer, AutoModelForMaskedLM
          import lightseq.inference as lsi

          tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
          hf_model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese")
          hf_model.to("cuda:0")
          ls_model = lsi.Bert("lightseq-bert-base-chinese.hdf5", 128)

          while True:
          raw_text = input("請(qǐng)輸入中文句子,要預(yù)測(cè)的字符用#代替:\n> ")
          input_text = raw_text.replace("#", "[MASK]")
          inputs = tokenizer(input_text, return_tensors="pt")
          input_ids = inputs["input_ids"]
          mask = inputs["attention_mask"]

          outputs = ls_model.infer(input_ids, mask)
          logits = hf_model.cls(torch.Tensor(outputs).to(dtype=torch.float, device="cuda:0"))
          output_ids = logits.argmax(axis=2)
          res_text = tokenizer.batch_decode(output_ids)

          res_text = res_text[0][1:-1].replace(" ", "")
          output_text = list(raw_text)
          for i in range(len(raw_text)):
          if raw_text[i] == "#":
          output_text[i] = res_text[i]
          print("> " + "".join(output_text))

          效果演示

          給大家看看效果,運(yùn)行我寫(xiě)好的代碼,我們來(lái)看看會(huì)輸出什么結(jié)果:

          請(qǐng)輸入中文句子,要預(yù)測(cè)的字符用#代替:
          > 巴黎是#國(guó)的首都。
          > 巴黎是法國(guó)的首都。

          代碼地址


          GitHub - bytedance/lightseq: LightSeq: A High Performance Library for Sequence Processing and Generationgithub.com/bytedance/lightseqgithub.com/bytedance/lightseq


          就在上周,首位外部貢獻(xiàn)者出現(xiàn)了,修復(fù)了LightSeq的詞嵌入表示的bug。

          在這里我們非常歡迎感興趣的同學(xué)來(lái)貢獻(xiàn)自己的代碼,包括但不局限于:修復(fù)bug、提供訓(xùn)練和推理樣例、支持更多模型結(jié)構(gòu)。


          瀏覽 90
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  四虎综合 | 91超碰在线免费观看 | 亚洲欧美日韩在线 | 日韩欧美精品 | 亚洲日本中文字幕乱码在线 |