轉(zhuǎn)自:新智元
眾所周知,PyTorch和TensorFlow是兩個(gè)非常受歡迎的深度學(xué)習(xí)框架。12月2日,英偉達(dá)發(fā)布了最新的TensorRT 8.2版本,對(duì)10億級(jí)參數(shù)的NLP模型進(jìn)行了優(yōu)化,其中就包括用于翻譯和文本生成的T5和GPT-2。而這一次,TensorRT讓實(shí)時(shí)運(yùn)行NLP應(yīng)用程序成為可能。TensorRT是一個(gè)高性能的深度學(xué)習(xí)推理優(yōu)化器,讓AI應(yīng)用擁有低延遲、高吞吐量的推理能力。新的TensorRT框架為PyTorch和TensorFlow提供了簡(jiǎn)單的API,帶來(lái)強(qiáng)大的FP16和INT8優(yōu)化功能。只需一行代碼,調(diào)用一個(gè)簡(jiǎn)單的API,模型在NVIDIA GPU上就能實(shí)現(xiàn)高達(dá)6倍的性能提升。Torch-TensorRT編譯器的架構(gòu)由三個(gè)階段組成:簡(jiǎn)化TorchScript模塊
轉(zhuǎn)換
執(zhí)行
Torch-TensorRT可以將常見(jiàn)操作直接映射到TensorRT上。值得注意的是,這種過(guò)程并不影響計(jì)算圖本身的功能。Torch-TensorRT自動(dòng)識(shí)別與TensorRT兼容的子圖,并將它們翻譯成TensorRT操作:具有靜態(tài)值的節(jié)點(diǎn)被評(píng)估并映射到常數(shù)。
描述張量計(jì)算的節(jié)點(diǎn)被轉(zhuǎn)換為一個(gè)或多個(gè)TensorRT層。
剩下的節(jié)點(diǎn)留在TorchScript中,形成一個(gè)混合圖,并作為標(biāo)準(zhǔn)的TorchScript模塊返回。
修改后的模塊會(huì)在嵌入TensorRT引擎后返回,也就是說(shuō)整個(gè)模型,包括PyTorch代碼、模型權(quán)重和TensorRT引擎,都可以在一個(gè)包中進(jìn)行移植。將Conv2d層轉(zhuǎn)化為T(mén)ensorRT引擎,而log_sigmoid則回到TorchScript JIT中當(dāng)執(zhí)行編譯模塊時(shí),TorchScript解釋器會(huì)調(diào)用TensorRT引擎并傳遞所有輸入。之后,TensorRT會(huì)將結(jié)果推送回解釋器,整個(gè)流程和使用普通的TorchScript模塊別無(wú)二致。PyTorch和TensorRT操作的運(yùn)行時(shí)執(zhí)行Torch-TensorRT通過(guò)兩種技術(shù)增強(qiáng)了對(duì)低精度推理的支持:訓(xùn)練后量化(PTQ)
量化感知訓(xùn)練(QAT)
對(duì)于PTQ來(lái)說(shuō),TensorRT用目標(biāo)領(lǐng)域的樣本數(shù)據(jù)訓(xùn)練模型,同時(shí)跟蹤FP32精度下的權(quán)重激活,以校準(zhǔn)FP32到INT8的映射,使FP32和INT8推理之間的信息損失最小。英偉達(dá)的安培架構(gòu)在A100 GPU上引入了第三代張量核心,可以在網(wǎng)絡(luò)權(quán)重中增加細(xì)粒度的稀疏性。因此,A100在提供最大吞吐量的同時(shí),也不會(huì)犧牲深度學(xué)習(xí)核心的矩陣乘法累積工作的準(zhǔn)確性。TensorRT支持在Tensor Core上執(zhí)行深度學(xué)習(xí)模型的稀疏層,而Torch-TensorRT將這種稀疏支持?jǐn)U展到卷積和全連接層。比如,用EfficientNet圖像分類(lèi)模型進(jìn)行推理,并計(jì)算PyTorch模型和經(jīng)過(guò)Torch-TensorRT優(yōu)化的模型的吞吐量。以下是在NVIDIA A100 GPU上取得的結(jié)果,batch size為1。在NVIDIA A100 GPU上比較原生PyTorch和Torch-TensorRt的吞吐量用TensorRT實(shí)現(xiàn)T5和GPT-2實(shí)時(shí)推理
Transformer架構(gòu)完全改變了自然語(yǔ)言處理領(lǐng)域。近年來(lái),許多新穎的大語(yǔ)言模型都建立在Transformer模塊之上,比如BERT、GPT和T5。T5可以用來(lái)回答問(wèn)題、做總結(jié)、翻譯文本和分類(lèi)文本。T5(Text-To-Text Transfer Transformer,文本到文本轉(zhuǎn)換Transformer)是谷歌創(chuàng)建的架構(gòu)。它將所有自然語(yǔ)言處理(NLP)任務(wù)重新組織成統(tǒng)一的文本到文本格式,其中輸入和輸出總是文本字符串。T5的架構(gòu)能夠?qū)⑾嗤哪P?、損失函數(shù)和超參數(shù)應(yīng)用于任何自然語(yǔ)言處理任務(wù),如機(jī)器翻譯、文檔摘要、問(wèn)題回答和分類(lèi)任務(wù),如情感分析。T5模型的靈感來(lái)自于一個(gè)NLP領(lǐng)域的共識(shí),即遷移學(xué)習(xí)已經(jīng)在自然語(yǔ)言處理中取得了最先進(jìn)的結(jié)果。遷移學(xué)習(xí)背后的原理是,在大量可用的未標(biāo)記數(shù)據(jù)上經(jīng)過(guò)預(yù)訓(xùn)練的模型,可以在較小的特定任務(wù)的已標(biāo)記數(shù)據(jù)集上進(jìn)行針對(duì)性的微調(diào)。事實(shí)證明,預(yù)訓(xùn)練-微調(diào)模型比從頭開(kāi)始在特定任務(wù)數(shù)據(jù)集上訓(xùn)練的模型具有更好的結(jié)果。T5模型在許多下游自然語(yǔ)言處理任務(wù)上獲得了最先進(jìn)的結(jié)果。已發(fā)布的預(yù)訓(xùn)練T5的參數(shù)最多高達(dá)3B和11B。雖說(shuō)都是語(yǔ)言模型,GPT-2的長(zhǎng)處在于生成優(yōu)秀的文本。GPT-2(Generative Pre-Trained Transformer 2)是一種自回歸無(wú)監(jiān)督語(yǔ)言模型,最初由OpenAI提出。它是由transformer解碼器塊構(gòu)建的,并在非常大的文本語(yǔ)料庫(kù)上進(jìn)行訓(xùn)練,以預(yù)測(cè)文本的下一個(gè)單詞。已發(fā)布的GPT-2模型中,最大的擁有1.5B參數(shù),能夠?qū)懗龇浅_B貫的文本。雖然較大的神經(jīng)語(yǔ)言模型通常會(huì)產(chǎn)生更好的結(jié)果,但將其部署到生產(chǎn)中會(huì)帶來(lái)很大的挑戰(zhàn),尤其是對(duì)于在線(xiàn)應(yīng)用程序,幾十毫秒的額外延遲足以讓用戶(hù)的體驗(yàn)變差很多。借助最新的TensorRT 8.2,英偉達(dá)針對(duì)大模型的實(shí)時(shí)推斷這一需求,優(yōu)化了T5和GPT-2。首先,從Hugging Face模型中心下載Hugging Face PyTorch T5模型及其相關(guān)的tokenizer。T5_VARIANT = 't5-small't5_model = T5ForConditionalGeneration.from_pretrained(T5_VARIANT)tokenizer = T5Tokenizer.from_pretrained(T5_VARIANT)config = T5Config(T5_VARIANT)
接下來(lái),將模型轉(zhuǎn)換為經(jīng)過(guò)優(yōu)化的TensorRT執(zhí)行引擎。不過(guò),在將T5模型轉(zhuǎn)換為T(mén)ensorRT引擎之前,需要將PyTorch模型轉(zhuǎn)換為一種中間通用格式:ONNX。ONNX是機(jī)器學(xué)習(xí)和深度學(xué)習(xí)模型的開(kāi)放格式。它能夠?qū)⑸疃葘W(xué)習(xí)和機(jī)器學(xué)習(xí)模型從不同的框架(如TensorFlow、PyTorch、MATLAB、Caffe和Keras)轉(zhuǎn)換為一個(gè)統(tǒng)一的格式。encoder_onnx_model_fpath = T5_VARIANT + "-encoder.onnx"decoder_onnx_model_fpath = T5_VARIANT + "-decoder-with-lm-head.onnx"t5_encoder = T5EncoderTorchFile(t5_model.to('cpu'), metadata)t5_decoder = T5DecoderTorchFile(t5_model.to('cpu'), metadata)onnx_t5_encoder = t5_encoder.as_onnx_model( os.path.join(onnx_model_path, encoder_onnx_model_fpath), force_overwrite=False)onnx_t5_decoder = t5_decoder.as_onnx_model( os.path.join(onnx_model_path, decoder_onnx_model_fpath), force_overwrite=False)
然后,將準(zhǔn)備好的T5 ONNX編碼器和解碼器轉(zhuǎn)換為優(yōu)化的TensorRT引擎。由于TensorRT執(zhí)行了許多優(yōu)化,例如融合操作、消除轉(zhuǎn)置操作和內(nèi)核自動(dòng)調(diào)整(在目標(biāo)GPU架構(gòu)上找到性能最佳的內(nèi)核),因此這一轉(zhuǎn)換過(guò)程可能需要一段時(shí)間。t5_trt_encoder_engine = T5EncoderONNXt5_trt_encoder_engine = T5EncoderONNXFile( os.path.join(onnx_model_path, encoder_onnx_model_fpath), metadata ).as_trt_engine(os.path.join(tensorrt_model_path, encoder_onnx_model_fpath) + ".engine")t5_trt_decoder_engine = T5DecoderONNXFile( os.path.join(onnx_model_path, decoder_onnx_model_fpath), metadata ).as_trt_engine(os.path.join(tensorrt_model_path, decoder_onnx_model_fpath) + ".engine")
最后,就可以用T5的TensorRT引擎進(jìn)行推理了。t5_trt_encoder = T5TRTEncoder( t5_trt_encoder_engine, metadata, tfm_config )t5_trt_decoder = T5TRTDecoder( t5_trt_decoder_engine, metadata, tfm_config )#generate outputencoder_last_hidden_state = t5_trt_encoder(input_ids=input_ids)outputs = t5_trt_decoder.greedy_search( input_ids=decoder_input_ids, encoder_hidden_states=encoder_last_hidden_state, stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length)]) )print(tokenizer.decode(outputs[0], skip_special_tokens=True))
同樣,對(duì)于GPT-2模型也可以按照相同的過(guò)程生成一個(gè)TensorRT引擎。優(yōu)化后的TensorRT引擎可以在HuggingFace推理工作流中替代原始的PyTorch模型。TensorRT vs PyTorch CPU、PyTorch GPU
通過(guò)將T5或GPT-2轉(zhuǎn)變?yōu)門(mén)ensorRT引擎,與PyTorch模型在GPU上的推斷時(shí)間相比,TensorRT的延遲降低了3至6倍,與PyTorch模型在CPU上的推斷時(shí)間相比,延遲更是降低了9至21倍。與PyTorch模型在CPU上的推斷時(shí)間相比,運(yùn)行在A100 GPU上的TensorRT引擎將延遲縮小了21倍。對(duì)NLP感興趣的朋友,要是想加速大語(yǔ)言模型的推理過(guò)程,就快來(lái)試試TensorRT 8.2吧!
參考資料:
https://developer.nvidia.com/blog/nvidia-announces-tensorrt-8-2-and-integrations-with-pytorch-and-tensorflow/?ncid=so-twit-314589#cid=dl13_so-twit_en-us
https://developer.nvidia.com/blog/accelerating-inference-up-to-6x-faster-in-pytorch-with-torch-tensorrt/
https://developer.nvidia.com/blog/optimizing-t5-and-gpt-2-for-real-time-inference-with-tensorrt/