PyTorch官宣:告別CUDA,GPU推理迎來Triton加速新時(shí)代
共 5412字,需瀏覽 11分鐘
·
2024-09-06 20:53
新智元報(bào)道
新智元報(bào)道
【新智元導(dǎo)讀】用英偉達(dá)的GPU,但可以不用CUDA?PyTorch官宣,借助OpenAI開發(fā)的Triton語言編寫內(nèi)核來加速LLM推理,可以實(shí)現(xiàn)和CUDA類似甚至更佳的性能。
又有多少開發(fā)者曾因?yàn)轭l頻閃爍的警報(bào)「CUDA版本必須與安裝的PyTorch匹配?。。 苟髨D炸鍵盤?
無論是TensorFlow還是Pytorch,GPU和CUDA搭配的概念早已深入骨髓。
如果我說,就在昨天,有款為LLM「量身定做」的CUDA-free推理上新了!你激不激動(dòng)?
原文地址:https://pytorch.org/blog/cuda-free-inference-for-llms/?hss_channel=tw-776585502606721024
那么,讓我們緊跟Pytorch的官方技術(shù)博客,一探究竟!看看它是如何將「自由」變?yōu)楝F(xiàn)實(shí)!
GPU的好搭子CUDA
CUDA(Compute Unified Device Architecture)到底是何方神物?為何被視為GPU的好搭子,LLMs的「利器」?
它是由英偉達(dá)開發(fā)的用于并行計(jì)算平臺(tái)和應(yīng)用程序的編程API,讓開發(fā)者能通過GPU開展高性能計(jì)算,包括:
1. 多個(gè)能并行處理任務(wù)的核心,實(shí)現(xiàn)多線程
2. 多種高效管理GPU內(nèi)存的方法,如全局內(nèi)存、共享內(nèi)存和常量內(nèi)存
3. 創(chuàng)建并管理多條并行線程,提高數(shù)據(jù)處理效率
4. 編譯器、調(diào)試器和性能分析工具組成的工具鏈,,幫助開發(fā)者優(yōu)化代碼
簡而言之,CUDA使GPU加速LLM訓(xùn)練變?yōu)楝F(xiàn)實(shí),大幅縮短了訓(xùn)練時(shí)間。
100%的Triton內(nèi)核
Pytorch最近發(fā)表了一篇技術(shù)博客,他們以兩個(gè)模型——Llama3-8B和IBM的Granite-8B Code為例,100%使用Triton內(nèi)核實(shí)現(xiàn)了FP16推理。
Granite-8B Code是由IBM開發(fā)的一種僅限解碼器的代碼模型,專為代碼生成任務(wù)設(shè)計(jì)。
倉庫地址:https://huggingface.co/ibm-granite/granite-8b-code-base-4k
值得注意的是,PyTorch指出他們實(shí)現(xiàn)了F16推理,也就是使用半精度浮點(diǎn)計(jì)算。
FP32單精度浮點(diǎn)數(shù)
F16半精度浮點(diǎn)數(shù)
相對(duì)于FP32,使用FP16可以將位數(shù)減少一半,因而減少了所需內(nèi)存,允許使用更大的模型或更大的批大小,且數(shù)據(jù)傳輸速度更快。
與F32相比,英偉達(dá)GPU提供的FP16將算術(shù)吞吐量提高了8倍,大幅加快了數(shù)學(xué)受限層的訓(xùn)練速度。
此外,PyTorch團(tuán)隊(duì)還著重強(qiáng)調(diào),計(jì)算全部是依賴OpenAI的Triton語言執(zhí)行的。
Triton是一種用于編寫高效自定義深度學(xué)習(xí)基元的語言和編譯器。
Triton的開發(fā)者致力于建立一個(gè)開源環(huán)境,以比CUDA更高效地編寫代碼,同時(shí)也期望它比現(xiàn)有的特定領(lǐng)域語言(domain-specific language)更具靈活性。
論文:https://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf
倉庫:https://github.com/triton-lang/triton
團(tuán)隊(duì)發(fā)現(xiàn),在英偉達(dá)H100上使用Triton內(nèi)核訓(xùn)練模型,性能可達(dá)CUDA內(nèi)核的76%~78%,在A100上也能達(dá)到62%~82%。
既然相比CUDA有一定的性能損失,那為什么要全部使用Triton語言?
PyTorch團(tuán)隊(duì)稱,Triton實(shí)現(xiàn)了LLM在GPU上的「可移植性」,能跨越多個(gè)不同個(gè)品牌的硬件,如英偉達(dá)、AMD、英特爾等。
此外,它還在Python中為GPU編程提供了更高的「抽象層」,使開發(fā)者有機(jī)會(huì)編寫自定義的具備更高性能的內(nèi)核。
最終,通過在H100和A100上使用Llama3-8B和Granite-8B的Triton和CUDA變體,并進(jìn)行推理階段的基準(zhǔn)測試,PyTorch團(tuán)隊(duì)證實(shí)了,Triton內(nèi)核能實(shí)現(xiàn)CUDA-Free的計(jì)算,且生成token的吞吐量有顯著提升。
內(nèi)核架構(gòu)
以Llama3為例,經(jīng)典的Transformer塊由一般由以下部分組成:
其中涉及的核心操作包括:
- RMS歸一化
- 矩陣乘法:融合QKV矩陣
- 旋轉(zhuǎn)位置編碼(RoPE)
- Flash Attention
- 矩陣乘法:投影為為輸出矩陣
- RMS歸一化
- 矩陣乘法:融合門控+向上投影
- 激活函數(shù)SiLU
- 逐元素(element-wise)矩陣乘法
- 矩陣乘法:向下投影
這些操作中都需要一個(gè)或多個(gè)GPU內(nèi)核進(jìn)行計(jì)算,雖然不同的Transformer模型的執(zhí)行細(xì)節(jié)可能有所不同,但核心操作是類似的。
例如,與Llama 3不同,IBM的Granite 8B Code模型在MLP層中使用了bias,此類更改確實(shí)需要對(duì)內(nèi)核的修改。
將這些Transformer塊堆疊在一起,再連接編碼層,就組成了一個(gè)經(jīng)典的Transformer模型。
模型推理
這些架構(gòu)代碼都會(huì)包含在model.py文件中,在PyTorch的eager執(zhí)行模式下,C會(huì)啟動(dòng)CUDA內(nèi)核執(zhí)行這些代碼。
為了讓Llama3-8B和Granite-8B模型100%用Triton語言實(shí)現(xiàn)端到端推理,我們需要手寫Triton內(nèi)核(kernel),或利用torch.compile模塊自動(dòng)生成。
對(duì)于較小的操作,比如 RMS歸一化、RoPE、SiLU函數(shù)和element-wise矩陣乘法,torch.compile可以自動(dòng)生成Triton內(nèi)核。
使用Nsight等工具即可對(duì)這些內(nèi)核進(jìn)行觀察,如下圖所示,自動(dòng)生成的內(nèi)核顯示為QKV乘法和flash attention之前的深綠色方塊:
使用torch.compile跟蹤 Llama3-8B,顯示CUDA內(nèi)核
通過Nsight的跟蹤信息可以觀察到,在Llama3-8B中,占端到端延遲80%的兩個(gè)主要操作是矩陣乘法和注意力內(nèi)核,而且它們依舊由CUDA內(nèi)核操作。
為了進(jìn)一步提升性能,我們開始手寫Triton內(nèi)核來替換上述兩個(gè)操作。
手寫Triton內(nèi)核
矩陣乘法
對(duì)于線性層中的矩陣乘法,編寫一個(gè)自定義的 FP16 Triton GEMM (General Matrix-Matrix Multiply)內(nèi)核,執(zhí)行通用的矩陣-矩陣乘法,其中利用了SplitK進(jìn)行工作分解。
為了實(shí)現(xiàn)最佳性能,還使用了窮舉搜索來調(diào)整SplitK GEMM內(nèi)核。
因?yàn)槊總€(gè)線性層的權(quán)重矩陣都有不同的形狀,如果要獲得最佳性能,就需要針對(duì)每種矩陣形狀調(diào)整Triton內(nèi)核。
Granite-8B和Llama3-8B的線性層權(quán)重矩陣規(guī)格如下:
調(diào)整每個(gè)線性層后,相比未調(diào)整的Triton內(nèi)核,可以實(shí)現(xiàn)1.2倍的端到端加速。
Flash Attention
Triton的flash attention內(nèi)核有一系列不同的配置和實(shí)現(xiàn),包括:
- AMD Flash
- OpenAI Flash
- Dao AI Lab Flash
- XFormers Flash
- PyTorch FlexAttention
首先,采用eager模式,之后用torch.compile的標(biāo)準(zhǔn)方法進(jìn)行編譯,并對(duì)文本生成質(zhì)量進(jìn)行評(píng)估;
上表總結(jié)了第2~5個(gè)內(nèi)核「開箱即用」時(shí)的表現(xiàn)。
這些結(jié)果表明,如果目標(biāo)是構(gòu)建一個(gè)端到端的生產(chǎn)級(jí)內(nèi)核,那么擁有一個(gè)能跑基準(zhǔn)測試的內(nèi)核還遠(yuǎn)遠(yuǎn)不夠。
后續(xù)測試中使用AMD flash attention內(nèi)核,因?yàn)樗梢酝ㄟ^torch.compile進(jìn)行編譯,且在eager和compile模式下都有清晰的輸出。
為了滿足torch.compile與AMD flash attention內(nèi)核的兼容性,我們需要自定義torch運(yùn)算符,主要包括以下兩步:
1. 將函數(shù)包裝到PyTorch自定義運(yùn)算符中
2. 在運(yùn)算符中添加一個(gè)FakeTensor Kernel,給定flash輸入張量的形狀(q、k 和 v),它可以提供一種計(jì)算flash內(nèi)核輸出形狀的方法
將模型中的運(yùn)算換為Triton的自定義內(nèi)核后,就能成功地進(jìn)行編譯和運(yùn)行,Nsight跟蹤信息如下圖所示:
對(duì)比圖5可以發(fā)現(xiàn),圖6就是100%使用Triton內(nèi)核的前向計(jì)算。
基準(zhǔn)測試
基準(zhǔn)測試中使用Granite-8B和Llama3-8B模型,在英偉達(dá)H100和A100上進(jìn)行單GPU運(yùn)行,并定義了兩種不同的配置:
Triton內(nèi)核配置使用:
1. Triton SplitK GEMM
2. AMD Triton Flash Attention
CUDA 內(nèi)核配置使用:
1. cuBLAS GEMM
2. cuDNN Flash Attention - 縮放點(diǎn)積注意力 (SDPA)
在典型的推理設(shè)置下,eager和torch編譯模式的吞吐量和token間延遲如下:
批大小=2,輸入序列長度=512,輸出序列長度=25
Triton模型在H100上的性能最高可達(dá)CUDA模型的78%,在A100上的性能最高可達(dá)82%。兩者間性能的差距可能源于矩陣乘法和flash attention的內(nèi)核延遲,下一節(jié)將詳細(xì)討論。
微基準(zhǔn)測試
解碼延遲時(shí)間對(duì)比,輸入是任意提示,批大小=1,提示長度=44
將端到端推理中的各部分進(jìn)行單獨(dú)對(duì)比,我們注意到以下兩點(diǎn):
1. Triton的matmul內(nèi)核比CUDA慢1.2~1.4倍
2. AMD的Triton Flash Attention內(nèi)核比CUDA SDPA慢1.6倍
這些結(jié)果表明,需要進(jìn)一步提升GEMM和Flash Attention等關(guān)鍵原語的內(nèi)核性能。
比如最近提出的FlashAttention-3、FlexAttention等工作提供了更好的方法來利用底層硬件,有希望在此基礎(chǔ)上為Triton進(jìn)一步加速。
將 FlexAttention與SDPA和AMD 的 Triton Flash內(nèi)核進(jìn)行比較,微基準(zhǔn)測試結(jié)果顯示,F(xiàn)lex有望被用于上下文更長、解碼規(guī)模更大的問題場景。
英偉達(dá)H100 SXM5 80GB上的FlexAttention內(nèi)核基準(zhǔn)測試
未來展望
接下來,我們期望進(jìn)一步優(yōu)化矩陣乘法(matmuls),以更充分地利用硬件。
比如使用不同的工作分解方法(類似StreamK的持久內(nèi)核技術(shù)),以加快基于Triton的方法。
我們還期望繼續(xù)探索FlexAttention和FlashAttention-3,進(jìn)一步縮小Triton和CUDA間的差距。
以上的實(shí)驗(yàn)只針對(duì)FP16精度,但早前的研究表明,與cuBLAS FP8 GEMM相比,F(xiàn)P8 Triton GEMM內(nèi)核表現(xiàn)更好。因此接下來的工作還會(huì)探討端到端FP8 LLM推理。
https://pytorch.org/blog/cuda-free-inference-for-llms/?utm_content=306418723&utm_medium=social&utm_source=twitter&hss_channel=tw-776585502606721024
