<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          PyTorch官宣:告別CUDA,GPU推理迎來Triton加速新時(shí)代

          共 5412字,需瀏覽 11分鐘

           ·

          2024-09-06 20:53



            新智元報(bào)道  

          編輯:喬楊 Frey
          【新智元導(dǎo)讀】用英偉達(dá)的GPU,但可以不用CUDA?PyTorch官宣,借助OpenAI開發(fā)的Triton語言編寫內(nèi)核來加速LLM推理,可以實(shí)現(xiàn)和CUDA類似甚至更佳的性能。

          試問,有多少機(jī)器學(xué)習(xí)小白曾被深度學(xué)習(xí)框架和CUDA的兼容問題所困擾?

          又有多少開發(fā)者曾因?yàn)轭l頻閃爍的警報(bào)「CUDA版本必須與安裝的PyTorch匹配?。。 苟髨D炸鍵盤?

          無論是TensorFlow還是Pytorch,GPU和CUDA搭配的概念早已深入骨髓。

          如果我說,就在昨天,有款為LLM「量身定做」的CUDA-free推理上新了!你激不激動(dòng)?

          原文地址:https://pytorch.org/blog/cuda-free-inference-for-llms/?hss_channel=tw-776585502606721024

          那么,讓我們緊跟Pytorch的官方技術(shù)博客,一探究竟!看看它是如何將「自由」變?yōu)楝F(xiàn)實(shí)!

          GPU的好搭子CUDA

          CUDA(Compute Unified Device Architecture)到底是何方神物?為何被視為GPU的好搭子,LLMs的「利器」?

          它是由英偉達(dá)開發(fā)的用于并行計(jì)算平臺(tái)和應(yīng)用程序的編程API,讓開發(fā)者能通過GPU開展高性能計(jì)算,包括:

          1. 多個(gè)能并行處理任務(wù)的核心,實(shí)現(xiàn)多線程

          2. 多種高效管理GPU內(nèi)存的方法,如全局內(nèi)存、共享內(nèi)存和常量內(nèi)存

          3. 創(chuàng)建并管理多條并行線程,提高數(shù)據(jù)處理效率

          4. 編譯器、調(diào)試器和性能分析工具組成的工具鏈,,幫助開發(fā)者優(yōu)化代碼

          簡而言之,CUDA使GPU加速LLM訓(xùn)練變?yōu)楝F(xiàn)實(shí),大幅縮短了訓(xùn)練時(shí)間。

          100%的Triton內(nèi)核

          Pytorch最近發(fā)表了一篇技術(shù)博客,他們以兩個(gè)模型——Llama3-8B和IBM的Granite-8B Code為例,100%使用Triton內(nèi)核實(shí)現(xiàn)了FP16推理。

          Granite-8B Code是由IBM開發(fā)的一種僅限解碼器的代碼模型,專為代碼生成任務(wù)設(shè)計(jì)。

          倉庫地址:https://huggingface.co/ibm-granite/granite-8b-code-base-4k

          值得注意的是,PyTorch指出他們實(shí)現(xiàn)了F16推理,也就是使用半精度浮點(diǎn)計(jì)算。

          FP32單精度浮點(diǎn)數(shù)

          F16半精度浮點(diǎn)數(shù)

          相對(duì)于FP32,使用FP16可以將位數(shù)減少一半,因而減少了所需內(nèi)存,允許使用更大的模型或更大的批大小,且數(shù)據(jù)傳輸速度更快。

          與F32相比,英偉達(dá)GPU提供的FP16將算術(shù)吞吐量提高了8倍,大幅加快了數(shù)學(xué)受限層的訓(xùn)練速度。

          此外,PyTorch團(tuán)隊(duì)還著重強(qiáng)調(diào),計(jì)算全部是依賴OpenAI的Triton語言執(zhí)行的。

          Triton是一種用于編寫高效自定義深度學(xué)習(xí)基元的語言和編譯器。

          Triton的開發(fā)者致力于建立一個(gè)開源環(huán)境,以比CUDA更高效地編寫代碼,同時(shí)也期望它比現(xiàn)有的特定領(lǐng)域語言(domain-specific language)更具靈活性。

          論文:https://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf

          倉庫:https://github.com/triton-lang/triton

          團(tuán)隊(duì)發(fā)現(xiàn),在英偉達(dá)H100上使用Triton內(nèi)核訓(xùn)練模型,性能可達(dá)CUDA內(nèi)核的76%~78%,在A100上也能達(dá)到62%~82%。

          既然相比CUDA有一定的性能損失,那為什么要全部使用Triton語言?

          PyTorch團(tuán)隊(duì)稱,Triton實(shí)現(xiàn)了LLM在GPU上的「可移植性」,能跨越多個(gè)不同個(gè)品牌的硬件,如英偉達(dá)、AMD、英特爾等。

          此外,它還在Python中為GPU編程提供了更高的「抽象層」,使開發(fā)者有機(jī)會(huì)編寫自定義的具備更高性能的內(nèi)核。

          最終,通過在H100和A100上使用Llama3-8B和Granite-8B的Triton和CUDA變體,并進(jìn)行推理階段的基準(zhǔn)測試,PyTorch團(tuán)隊(duì)證實(shí)了,Triton內(nèi)核能實(shí)現(xiàn)CUDA-Free的計(jì)算,且生成token的吞吐量有顯著提升。

          內(nèi)核架構(gòu)

          以Llama3為例,經(jīng)典的Transformer塊由一般由以下部分組成:

          其中涉及的核心操作包括:

          - RMS歸一化

          - 矩陣乘法:融合QKV矩陣

          - 旋轉(zhuǎn)位置編碼(RoPE)

          - Flash Attention

          - 矩陣乘法:投影為為輸出矩陣

          - RMS歸一化

          - 矩陣乘法:融合門控+向上投影

          - 激活函數(shù)SiLU

          - 逐元素(element-wise)矩陣乘法

          - 矩陣乘法:向下投影

          這些操作中都需要一個(gè)或多個(gè)GPU內(nèi)核進(jìn)行計(jì)算,雖然不同的Transformer模型的執(zhí)行細(xì)節(jié)可能有所不同,但核心操作是類似的。

          例如,與Llama 3不同,IBM的Granite 8B Code模型在MLP層中使用了bias,此類更改確實(shí)需要對(duì)內(nèi)核的修改。

          將這些Transformer塊堆疊在一起,再連接編碼層,就組成了一個(gè)經(jīng)典的Transformer模型。

          模型推理

          這些架構(gòu)代碼都會(huì)包含在model.py文件中,在PyTorch的eager執(zhí)行模式下,C會(huì)啟動(dòng)CUDA內(nèi)核執(zhí)行這些代碼。

          為了讓Llama3-8B和Granite-8B模型100%用Triton語言實(shí)現(xiàn)端到端推理,我們需要手寫Triton內(nèi)核(kernel),或利用torch.compile模塊自動(dòng)生成。

          對(duì)于較小的操作,比如 RMS歸一化、RoPE、SiLU函數(shù)和element-wise矩陣乘法,torch.compile可以自動(dòng)生成Triton內(nèi)核。

          使用Nsight等工具即可對(duì)這些內(nèi)核進(jìn)行觀察,如下圖所示,自動(dòng)生成的內(nèi)核顯示為QKV乘法和flash attention之前的深綠色方塊:

          使用torch.compile跟蹤 Llama3-8B,顯示CUDA內(nèi)核

          通過Nsight的跟蹤信息可以觀察到,在Llama3-8B中,占端到端延遲80%的兩個(gè)主要操作是矩陣乘法和注意力內(nèi)核,而且它們依舊由CUDA內(nèi)核操作。

          為了進(jìn)一步提升性能,我們開始手寫Triton內(nèi)核來替換上述兩個(gè)操作。

          手寫Triton內(nèi)核

          矩陣乘法

          對(duì)于線性層中的矩陣乘法,編寫一個(gè)自定義的 FP16 Triton GEMM (General Matrix-Matrix Multiply)內(nèi)核,執(zhí)行通用的矩陣-矩陣乘法,其中利用了SplitK進(jìn)行工作分解。

          為了實(shí)現(xiàn)最佳性能,還使用了窮舉搜索來調(diào)整SplitK GEMM內(nèi)核。

          因?yàn)槊總€(gè)線性層的權(quán)重矩陣都有不同的形狀,如果要獲得最佳性能,就需要針對(duì)每種矩陣形狀調(diào)整Triton內(nèi)核。

          Granite-8B和Llama3-8B的線性層權(quán)重矩陣規(guī)格如下:

          調(diào)整每個(gè)線性層后,相比未調(diào)整的Triton內(nèi)核,可以實(shí)現(xiàn)1.2倍的端到端加速。

          Flash Attention

          Triton的flash attention內(nèi)核有一系列不同的配置和實(shí)現(xiàn),包括:

          - AMD Flash

          - OpenAI Flash

          - Dao AI Lab Flash

          - XFormers Flash

          - PyTorch FlexAttention

          首先,采用eager模式,之后用torch.compile的標(biāo)準(zhǔn)方法進(jìn)行編譯,并對(duì)文本生成質(zhì)量進(jìn)行評(píng)估;

          上表總結(jié)了第2~5個(gè)內(nèi)核「開箱即用」時(shí)的表現(xiàn)。

          這些結(jié)果表明,如果目標(biāo)是構(gòu)建一個(gè)端到端的生產(chǎn)級(jí)內(nèi)核,那么擁有一個(gè)能跑基準(zhǔn)測試的內(nèi)核還遠(yuǎn)遠(yuǎn)不夠。

          后續(xù)測試中使用AMD flash attention內(nèi)核,因?yàn)樗梢酝ㄟ^torch.compile進(jìn)行編譯,且在eager和compile模式下都有清晰的輸出。

          為了滿足torch.compile與AMD flash attention內(nèi)核的兼容性,我們需要自定義torch運(yùn)算符,主要包括以下兩步:

          1. 將函數(shù)包裝到PyTorch自定義運(yùn)算符中

          2. 在運(yùn)算符中添加一個(gè)FakeTensor Kernel,給定flash輸入張量的形狀(q、k 和 v),它可以提供一種計(jì)算flash內(nèi)核輸出形狀的方法

          將模型中的運(yùn)算換為Triton的自定義內(nèi)核后,就能成功地進(jìn)行編譯和運(yùn)行,Nsight跟蹤信息如下圖所示:

          對(duì)比圖5可以發(fā)現(xiàn),圖6就是100%使用Triton內(nèi)核的前向計(jì)算。

          基準(zhǔn)測試

          基準(zhǔn)測試中使用Granite-8B和Llama3-8B模型,在英偉達(dá)H100和A100上進(jìn)行單GPU運(yùn)行,并定義了兩種不同的配置:

          Triton內(nèi)核配置使用:

          1. Triton SplitK GEMM

          2. AMD Triton Flash Attention

          CUDA 內(nèi)核配置使用:

          1. cuBLAS GEMM

          2. cuDNN Flash Attention - 縮放點(diǎn)積注意力 (SDPA)

          在典型的推理設(shè)置下,eager和torch編譯模式的吞吐量和token間延遲如下:

          批大小=2,輸入序列長度=512,輸出序列長度=25

          Triton模型在H100上的性能最高可達(dá)CUDA模型的78%,在A100上的性能最高可達(dá)82%。兩者間性能的差距可能源于矩陣乘法和flash attention的內(nèi)核延遲,下一節(jié)將詳細(xì)討論。

          微基準(zhǔn)測試

          解碼延遲時(shí)間對(duì)比,輸入是任意提示,批大小=1,提示長度=44

          將端到端推理中的各部分進(jìn)行單獨(dú)對(duì)比,我們注意到以下兩點(diǎn):

          1. Triton的matmul內(nèi)核比CUDA慢1.2~1.4倍

          2. AMD的Triton Flash Attention內(nèi)核比CUDA SDPA慢1.6倍

          這些結(jié)果表明,需要進(jìn)一步提升GEMM和Flash Attention等關(guān)鍵原語的內(nèi)核性能。

          比如最近提出的FlashAttention-3、FlexAttention等工作提供了更好的方法來利用底層硬件,有希望在此基礎(chǔ)上為Triton進(jìn)一步加速。

          將 FlexAttention與SDPA和AMD 的 Triton Flash內(nèi)核進(jìn)行比較,微基準(zhǔn)測試結(jié)果顯示,F(xiàn)lex有望被用于上下文更長、解碼規(guī)模更大的問題場景。

          英偉達(dá)H100 SXM5 80GB上的FlexAttention內(nèi)核基準(zhǔn)測試

          未來展望

          接下來,我們期望進(jìn)一步優(yōu)化矩陣乘法(matmuls),以更充分地利用硬件。

          比如使用不同的工作分解方法(類似StreamK的持久內(nèi)核技術(shù)),以加快基于Triton的方法。

          我們還期望繼續(xù)探索FlexAttention和FlashAttention-3,進(jìn)一步縮小Triton和CUDA間的差距。

          以上的實(shí)驗(yàn)只針對(duì)FP16精度,但早前的研究表明,與cuBLAS FP8 GEMM相比,F(xiàn)P8 Triton GEMM內(nèi)核表現(xiàn)更好。因此接下來的工作還會(huì)探討端到端FP8 LLM推理。

          參考資料:

          https://pytorch.org/blog/cuda-free-inference-for-llms/?utm_content=306418723&utm_medium=social&utm_source=twitter&hss_channel=tw-776585502606721024




          瀏覽 48
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  人人摸人人摸人人摸 | 免费A片视频在线观看 | 色综合久久天天综合网 | 久久逼网| 波霸巨大乳一区二区三区 |