Mamba一作再祭神作,H100利用率飆至75%!FlashAttention三代性能翻倍,比標(biāo)準(zhǔn)注意力快16倍
共 6742字,需瀏覽 14分鐘
·
2024-07-12 22:00
極市導(dǎo)讀
時(shí)隔一年,F(xiàn)lashAttention又推出了第三代更新,專(zhuān)門(mén)針對(duì)H100 GPU的新特性進(jìn)行優(yōu)化,在之前的基礎(chǔ)上又實(shí)現(xiàn)了1.5~2倍的速度提升。>>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺(jué)的最前沿
FlashAttention又有后續(xù)了!
去年7月,F(xiàn)lashAttention-2發(fā)布,相比第一代實(shí)現(xiàn)了2倍的速度提升,比PyTorch上的標(biāo)準(zhǔn)注意力操作快5~9倍,達(dá)到A100上理論最大FLOPS的50~73%,實(shí)際訓(xùn)練速度可達(dá)225 TFLOPS(模型FLOPs利用率為72%)。
然而,去年發(fā)布FlashAttenion-2尚未運(yùn)用到硬件中的最新功能,在H100上僅實(shí)現(xiàn)了理論最大FLOPS 35%的利用率。
時(shí)隔一年,F(xiàn)lashAttention-3歸來(lái),將H100的FLOP利用率再次拉到75%,相比第二代又實(shí)現(xiàn)了1.5~2倍的速度提升,在H100上的速度達(dá)到740 TFLOPS。
論文地址:https://tridao.me/publications/flash3/flash3.pdf
值得一提的是,F(xiàn)lashAttention v1和v2的第一作者也是Mamba的共同一作,普林斯頓大學(xué)助理教授Tri Dao,他的名字也在這次FlashAttention-3的作者列表中。
Tri Dao師從于Christopher Ré和Stefano Ermon,去年6月在斯坦福大學(xué)獲得計(jì)算機(jī)博士學(xué)位,畢業(yè)后擔(dān)任Together AI的首席科學(xué)家,并從今年6月開(kāi)始入職普林斯頓大學(xué)。
用最新最強(qiáng)的GPU,達(dá)到超高的算力利用率,這下LLM的性能和上下文長(zhǎng)度又要迎來(lái)一波暴漲了。
PyTorch官方也在推特上轉(zhuǎn)發(fā)了這個(gè)消息,想必我們能在不久后看到FlashAttention被集成到PyTorch中。
目前論文還未上傳到arxiv平臺(tái),只發(fā)表在Tri Dao本人的博客中,但GitHub上已經(jīng)發(fā)布了用于Beta測(cè)試的源代碼。
項(xiàng)目地址:https://github.com/Dao-AILab/flash-attention
網(wǎng)友在其中發(fā)現(xiàn)了重要的華點(diǎn)——這一版的FlashAttention專(zhuān)攻H100 GPU,只能在H100或H800上運(yùn)行,不支持其他GPU型號(hào)。
所以即使有了源代碼,大多數(shù)只有4090的開(kāi)發(fā)者也應(yīng)該運(yùn)行不起來(lái),還得先攢錢(qián)買(mǎi)H100。
面對(duì)這篇論文,財(cái)大氣粗的科技巨頭可以說(shuō),「太棒了,現(xiàn)在我們99999個(gè)H100的算力集群還能更快。」
普通研發(fā)人員只能說(shuō),「啊,得買(mǎi)幾個(gè)H100來(lái)試試,但不幸的是,我只有2個(gè)腎。」
H100利用率飆至75%,LLM速度再翻倍
對(duì)Transformer架構(gòu)來(lái)說(shuō),注意力機(jī)制既是核心優(yōu)勢(shì),也是重要瓶頸。其理論計(jì)算量是序列長(zhǎng)度的二次方,因此拖慢了計(jì)算速度,阻礙了在LLM中的長(zhǎng)上下文應(yīng)用。
FlashAttention(以及FlashAttention-2)通過(guò)減少內(nèi)存讀寫(xiě)次數(shù),開(kāi)創(chuàng)了一種在GPU上加速注意力機(jī)制的方法,現(xiàn)在大多數(shù)庫(kù)都使用它來(lái)加速Transformer的訓(xùn)練和推理。
這使得大語(yǔ)言模型的上下文長(zhǎng)度在過(guò)去兩年中大幅增加,從2-4K(如GPT-3、OPT)擴(kuò)展到128K(如GPT-4),甚至達(dá)到1M(如Llama 3、Gemini 1.5 Pro)。
然而,盡管取得了顯著進(jìn)展,F(xiàn)lashAttention還沒(méi)有充分利用現(xiàn)代硬件的新功能,F(xiàn)lashAttention-2在H100 GPU上僅實(shí)現(xiàn)了理論最大FLOPs的35%利用率。
針對(duì)最新的Hopper GPU進(jìn)行改進(jìn),F(xiàn)lashAttention-3主要使用了如下3種技術(shù)加速注意力機(jī)制:利用Tensor Cores和TMA的異步性——
1)通過(guò)warp-specialization技術(shù)重疊整體計(jì)算和數(shù)據(jù)移動(dòng);
2)其次,交替進(jìn)行塊狀矩陣乘法和softmax操作;
3)利用硬件支持進(jìn)行FP8低精度的非相干處理。
在FP16模式下,F(xiàn)lashAttention-3比FlashAttention-2快1.5~2倍,達(dá)到740 TFLOPS,即H100理論最大FLOPs的75%。
在FP8模式下,F(xiàn)lashAttention-3接近1.2 PFLOPS,誤差比基線FP8注意力小2.6倍。
FlashAttention-3的改進(jìn)將帶來(lái)以下變化:
-
更高效的GPU利用率:新技術(shù)使H100 GPU的利用率從之前的35%提升到75%。這使得LLM的訓(xùn)練和運(yùn)行速度顯著提高,達(dá)到了之前版本的1.5~2倍。
-
更好的低精度性能:FlashAttention-3在保持準(zhǔn)確性的同時(shí),可以使用FP8這樣的較低精度。這不僅加快了處理速度,還能減少內(nèi)存使用,從而為運(yùn)行大規(guī)模AI操作的客戶(hù)節(jié)省成本并提高效率。
-
在LLMs中使用更長(zhǎng)上下文的能力:通過(guò)加速注意力機(jī)制,F(xiàn)lashAttention-3使AI模型能夠更高效地處理更長(zhǎng)的文本。這意味著應(yīng)用程序可以理解和生成更長(zhǎng)、更復(fù)雜的內(nèi)容,而不會(huì)影響速度。
FlashAttention回顧
FlashAttention是一種對(duì)注意力計(jì)算進(jìn)行重新排序的算法,利用分塊和重計(jì)算技術(shù),大大加快了計(jì)算速度,并將內(nèi)存使用量從序列長(zhǎng)度的二次方減少到線性。
利用分塊技術(shù),將輸入數(shù)據(jù)塊從GPU內(nèi)存中的HBM(高速帶寬緩存)加載到SRAM中,對(duì)其進(jìn)行注意力計(jì)算,然后在HBM中更新輸出。
這種方法不將計(jì)算過(guò)程中的大型注意力矩陣寫(xiě)入HBM,減少了內(nèi)存的讀寫(xiě)總量,從而實(shí)現(xiàn)了2~4倍的速度提升。
下面是FlashAttention前向傳遞的示意圖:通過(guò)分塊和softmax重新縮放,以塊為單位進(jìn)行操作,避免了從HBM中頻繁讀寫(xiě),同時(shí)能夠準(zhǔn)確地獲得結(jié)果而無(wú)需近似計(jì)算。
Hopper GPU上的新硬件功能:WGMMA、TMA、FP8
雖然FlashAttention-2在Ampere系列(如A100)GPU上,可以達(dá)到理論最大FLOPS的72%,但尚未充分利用Hopper GPU的新功能來(lái)最大化性能。
Hopper特有的一些新功能包括:
1. WGMMA(Warpgroup Matrix Multiply-Accumulate)
這個(gè)新功能利用了Hopper上的新Tensor Cores,相較于Ampere中原來(lái)的mma.sync指令,吞吐量得到大大提高。
2. TMA(Tensor Memory Accelerator)
這是一種特殊的硬件單元,可以加速全局內(nèi)存和共享內(nèi)存之間的數(shù)據(jù)傳輸,并負(fù)責(zé)所有的索引計(jì)算和越界預(yù)測(cè)。這代替了寄存器的部分工作,從而能夠釋放寄存器資源,用于增加塊大小、提高效率。
3. 低精度的FP8
FP8能夠使Tensor Core的吞吐量翻倍,例如,用FP16實(shí)現(xiàn)989 TFLOPS計(jì)算量的同時(shí),F(xiàn)P8能達(dá)到1978 TFLOPS。但由于使用更少的位來(lái)表示浮點(diǎn)數(shù),犧牲了一些計(jì)算準(zhǔn)確性。
進(jìn)化后的FlashAttention-3,充分利用了Hopper GPU的以上所有新功能,并使用了NVIDIA CUTLASS庫(kù)的強(qiáng)大抽象。
僅僅是用這些功能重寫(xiě)FlashAttention,就顯著加快了速度,從FlashAttention-2 FP16前向計(jì)算的350 TFLOPS提升到大約540-570 TFLOPS。
不過(guò),Hopper上新指令(WGMMA和TMA)的異步性,提供了另一種方式——通過(guò)重疊操作來(lái)提取更高的性能。
具體來(lái)說(shuō),研究人員開(kāi)發(fā)了新技術(shù)來(lái)重疊矩陣函數(shù)和softmax的新技術(shù)。
異步處理:重疊GEMM和Softmax
為什么要重疊?
在注意力機(jī)制中,主要涉及兩種操作:GEMM(即Q和K之間的矩陣乘法,以及注意力概率P和V之間的矩陣乘法)和softmax。
為什么需要將它們重疊呢?
大部分的浮點(diǎn)運(yùn)算不都是在GEMM中進(jìn)行的嗎?
只要GEMM足夠快(例如使用WGMMA指令進(jìn)行計(jì)算),GPU不就應(yīng)該一直高速運(yùn)轉(zhuǎn)嗎?
事實(shí)上,并不是GEMM的問(wèn)題,而是softmax會(huì)占用令人驚訝的大量時(shí)間。
問(wèn)題在于,在現(xiàn)代加速器上,非矩陣乘法操作的速度遠(yuǎn)不及矩陣乘法操作。
像指數(shù)函數(shù)(用于softmax)這樣的特殊函數(shù),其吞吐量甚至比浮點(diǎn)乘法加法還低。
它們是由多功能單元(multi-function unit)計(jì)算的,與負(fù)責(zé)浮點(diǎn)乘加或矩陣乘加運(yùn)算的單元分開(kāi)計(jì)算。
例如,H100 GPU SXM5的FP16矩陣乘法性能可以達(dá)到989 TFLOPS,但特殊函數(shù)的性能只有3.9 TFLOPS(吞吐量低了256倍)!
head維度為128時(shí),矩陣乘法的FLOPS運(yùn)算是指數(shù)函數(shù)的512倍,這意味著指數(shù)函數(shù)的計(jì)算時(shí)間可以占到矩陣乘法的一半。
對(duì)于FP8,情況更糟,因?yàn)榫仃嚦朔ǖ倪\(yùn)算速度是指數(shù)函數(shù)的兩倍,但指數(shù)函數(shù)的速度卻沒(méi)有變化。
因此,理想情況下,應(yīng)該讓矩陣乘法和softmax并行操作。
當(dāng)Tensor Cores忙于矩陣乘法時(shí),多功能單元應(yīng)該在計(jì)算指數(shù)函數(shù)!
通過(guò)乒乓調(diào)度實(shí)現(xiàn)跨warp組重疊
第一種,也是最簡(jiǎn)單的重疊GEMM和softmax的方法,那就是什么都不做!
warp調(diào)度器已經(jīng)在嘗試調(diào)度warp,當(dāng)某些warp被阻塞(例如,等待GEMM結(jié)果)時(shí),其他warp可以繼續(xù)運(yùn)行。
也就是說(shuō),warp調(diào)度器已經(jīng)在幫研究者做一些重疊工作了,而且不引入額外成本。
然而,我們依舊可以通過(guò)手動(dòng)調(diào)度來(lái)進(jìn)一步優(yōu)化。
例如,如果有兩個(gè)warp組(標(biāo)記為1和2,每個(gè)warp組包含4個(gè)warp),可以使用同步屏障(bar.sync),使得warp組1先執(zhí)行其GEMM指令(例如,GEMM1的一次迭代和GEMM0的下一次迭代),然后warp組2執(zhí)行其GEMM,同時(shí)warp組1執(zhí)行其softmax,依次循環(huán)。
下圖展示了這種「乒乓」調(diào)度,其中相同顏色表示相同的迭代。
這將使我們能夠在另一個(gè)warp組進(jìn)行GEMM計(jì)算的同時(shí),執(zhí)行softmax操作。
當(dāng)然,這個(gè)圖只是一個(gè)簡(jiǎn)化示意圖;實(shí)際調(diào)度并沒(méi)有這么整齊。
然而,乒乓調(diào)度可以將FP16注意力在前向計(jì)算中的性能從大約570 TFLOPS提高到620 TFLOPS(head維度128,序列長(zhǎng)度8K)。
在單個(gè)warp組內(nèi)重疊GEMM和Softmax
即使在一個(gè)warp組內(nèi),也可以在warp組進(jìn)行GEMM計(jì)算時(shí),同時(shí)運(yùn)行softmax的一部分。
下圖展示了這種情況,其中相同顏色表示相同的迭代。
使用這種流水線,F(xiàn)P16注意力前向計(jì)算的吞吐量從大約620 TFLOPS,提高到大約640-660 TFLOPS,但代價(jià)是增加了寄存器壓力。
這種情況下,就需要更多的寄存器,來(lái)同時(shí)保存GEMM的累加器,和softmax的輸入/輸出。
總之,這種技術(shù)可以提供一種有利的折衷方案。
低精度:通過(guò)非相干處理減少量化誤差
在LLM的激活函數(shù)中,可能會(huì)出現(xiàn)一些比其他特征大得多的異常值,這會(huì)增加量化的難度,并產(chǎn)生更大的量化誤差。
為此,論文采用了一種常用的量化技術(shù)——非相干處理(incoherent processing),例如QuIP論文中描述的,通過(guò)將查詢(xún)和鍵乘以一個(gè)隨機(jī)正交矩陣來(lái)「分散」這些異常值,從而減少量化誤差。
特別的,論文使用Hadamard變換(帶有隨機(jī)正負(fù)號(hào))產(chǎn)生隨機(jī)矩陣,可以在O(d log d)而不是O(d^2)時(shí)間內(nèi)完成每個(gè)注意力頭的計(jì)算,其中d是head維度。
由于Hadamard變換受限于內(nèi)存帶寬,它可以與之前的操作,如旋轉(zhuǎn)嵌入,進(jìn)行無(wú)成本融合,后者同樣受內(nèi)存帶寬的限制。
在實(shí)驗(yàn)中,Q、K、V是從標(biāo)準(zhǔn)正態(tài)分布生成的,但其中的0.1%有更大的數(shù)量級(jí)(以模擬異常值)。
結(jié)果發(fā)現(xiàn),非相干處理可以將量化誤差減少2.6倍。下表展示了數(shù)值誤差對(duì)比。
注意力基準(zhǔn)測(cè)試
接下來(lái),論文展示了一些FlashAttention-3的測(cè)試結(jié)果,并將其與FlashAttention-2以及PyTorch中Triton和cuDNN的注意力實(shí)現(xiàn)進(jìn)行了比較(注意,后兩者都已經(jīng)利用了Hopper GPU的新硬件特性)。
對(duì)于FP16精度,他們發(fā)現(xiàn)FlashAttention-3相對(duì)于FlashAttention-2,有大約1.6倍到2.0倍的加速效果。
序列長(zhǎng)度在在1k或以下時(shí),F(xiàn)A3相比Triton和cuDNN的優(yōu)勢(shì)并不明顯,有時(shí)甚至?xí)浜蟆?/p>
但隨著序列長(zhǎng)度和head維度逐漸增大,F(xiàn)A3與其他實(shí)現(xiàn)方案的差距也越來(lái)越顯著,可見(jiàn)這種算法非常適用于大規(guī)模運(yùn)算場(chǎng)景。
相較于標(biāo)準(zhǔn)注意力,F(xiàn)lashAttention-3的速度快了3-16倍。
對(duì)于FP8精度,F(xiàn)lashAttention-3的性能可以接近1.2 PFLOPS,但會(huì)在某些情況下落后于Triton和cuDNN的性能。
除了前向計(jì)算,F(xiàn)A3后向傳播的運(yùn)算速度也同樣領(lǐng)先其他方案。
以上重點(diǎn)介紹了FlashAttention針對(duì)Hopper GPU新特性實(shí)現(xiàn)的優(yōu)化,此外,論文中也詳述了其他方面的優(yōu)化,包括可變長(zhǎng)度序列、持久內(nèi)核和FP8內(nèi)核轉(zhuǎn)置等。
可以看到,能夠充分利用硬件性能的算法可以顯著提升效率,還能解鎖新的模型能力,比如處理更長(zhǎng)的上下文。
目前,F(xiàn)lashAttetion-3著重訓(xùn)練過(guò)程的優(yōu)化,未來(lái)的工作可以繼續(xù)提升推理性能,并推廣到Hopper GPU以外的其他硬件架構(gòu)。
參考資料:
https://tridao.me/publications/flash3/flash3.pdfhttps://tridao.me/blog/2024/flash3/
公眾號(hào)后臺(tái)回復(fù)“數(shù)據(jù)集”獲取100+深度學(xué)習(xí)各方向資源整理
極市干貨
點(diǎn)擊閱讀原文進(jìn)入CV社區(qū)
收獲更多技術(shù)干貨
