老黃又贏麻了!英偉達(dá)親自下場推出 FlashAttention-3:H100利用率飆升至75%!
共 3732字,需瀏覽 8分鐘
·
2024-07-12 14:11
點藍(lán)色字關(guān)注“機器學(xué)習(xí)算法工程師”
設(shè)為星標(biāo),干貨直達(dá)!
740 TFLOPS!迄今最強 FlashAttention 來了。
隨著大型語言模型(LLM)加速落地,擴展模型上下文窗口變得越來越重要。然而,Transformer 架構(gòu)的核心 —— 注意力層的時間復(fù)雜度和空間復(fù)雜度與輸入序列長度的平方成正比。這使得擴展模型上下文窗口存在挑戰(zhàn)。
2022 年,一種快速、內(nèi)存高效的注意力算法 ——FlashAttention 問世,該算法無需任何近似即可加速注意力并減少內(nèi)存占用。
FlashAttention 對注意力計算進行重新排序的算法,并利用 tiling 和重計算來顯著加快計算速度,將內(nèi)存使用量從序列長度的二次減少到線性。
2023 年,研究團隊宣布推出 FlashAttention-2,在算法、并行化和工作分區(qū)等方面有了顯著改進。
現(xiàn)在,來自 Meta、英偉達(dá)、Together AI 等機構(gòu)的研究者宣布推出 FlashAttention-3,它采用了加速 Hopper GPU 注意力的三種主要技術(shù):
通過 warp-specialization 重疊整體計算和數(shù)據(jù)移動;
交錯分塊 matmul 和 softmax 運算;
利用硬件支持 FP8 低精度的不連貫處理。
FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍,高達(dá) 740 TFLOPS,即 H100 理論最大 FLOPS 利用率為 75%。使用 FP8,F(xiàn)lashAttention-3 的速度更是接近 1.2 PFLOPS。
FlashAttention-3 的改進將帶來:
更高效的 GPU 利用率:H100 理論最大 FLOPS 利用率為 75%,而之前僅為 35%。這使得 LLM 的訓(xùn)練和運行速度比以前的版本快得多。
較低精度下更好的性能:FlashAttention-3 可以在保持精度的同時使用較低精度的數(shù)字 (FP8)。這可以實現(xiàn)更快的處理速度并可能降低內(nèi)存使用量,從而為運行大規(guī)模人工智能操作的客戶節(jié)省成本并提高效率。
能夠在 LLM 中使用更長的上下文:通過加速注意力機制,F(xiàn)lashAttention-3 使 AI 模型能夠更有效地處理更長的文本片段。這使得應(yīng)用程序能夠理解并生成更長、更復(fù)雜的內(nèi)容而不會減慢速度。
論文標(biāo)題:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
論文地址:https://tridao.me/publications/flash3/flash3.pdf
論文作者之一 、FlashAttention1-3 版本的參與者 Tri Dao 表示:FlashAttention 被廣泛用于加速 Transformers,已經(jīng)使注意力速度提高了 4-8 倍,但尚未利用現(xiàn)代 GPU。因而他們發(fā)布了 FlashAttention-3:在 FP16 上速度提高了 1.5-2 倍,在 H100 上高達(dá) 740 TFLOPS(75% 實用性),F(xiàn)P8 接近 1.2 PFLOPS!
Hopper GPU 硬件特性:WGMMA、TMA、FP8
雖然 FlashAttention-2 在 Ampere (A100) GPU 上可以實現(xiàn) 70% 的理論最大 FLOPS,但它尚未利用 Hopper GPU 上的新功能來最大限度地提高性能。接下來文章描述了一些新的 Hopper 特定功能,以及它們?yōu)楹稳绱酥匾?/span>
首先是 WGMMA(Warpgroup Matrix Multiply-Accumulate),該功能利用了 Hopper 架構(gòu)上新的張量內(nèi)核,比 Ampere 架構(gòu)具有更高的吞吐量。
然后是 TMA(Tensor Memory Accelerator),這是一個特殊的硬件單元,可以加速全局內(nèi)存和共享內(nèi)存之間的數(shù)據(jù)傳輸,用于處理所有索引計算和邊界外預(yù)測。這樣一來寄存器就釋放了,寄存器是增加 tile 大小和效率的寶貴資源。
低精度 FP8,讓 Tensor Core 吞吐量翻了一倍。
FlashAttention-3 充分利用了 Hopper 架構(gòu)的所有這些新功能。
異步:GEMM 和 Softmax 重疊
注意力機制主要有兩個操作,GEMM 和 softmax。為什么要將它們重疊?
問題在于在現(xiàn)代加速器上,非矩陣乘法(matmul)運算比矩陣乘法運算慢。特殊函數(shù)如指數(shù)運算(如 softmax 函數(shù))的吞吐量甚至低于浮點乘加操作;這些運算是由多功能單元處理的,這是一個與浮點乘加或矩陣乘加不同的單元。
理想情況下,研究者希望矩陣乘法和 softmax 能夠并行操作。當(dāng) Tensor Cores 忙于矩陣乘法時,多功能單元應(yīng)當(dāng)在計算指數(shù)運算!
Inter-warpgroup 重疊
重疊 GEMM 和 softmax 最簡單的方法是什么都不做,warp 調(diào)度程序會免費完成部分重疊。下圖說明了 pingpong 調(diào)度,其中相同的顏色表示相同的迭代。
Intra-warpgroup 重疊
即使在一個 warpgroup 中,研究者也可以在運行該 warpgroup 的 GEMM 時運行 softmax 的某些部分。如圖所示,相同的顏色表示相同的迭代。
這種 pipeline 流程可以將 FP16 注意力前向傳播的吞吐量從大約 620 TFLOPS 提高到 640-660 TFLOPS,但代價是更高的寄存器壓力,因而需要更多的寄存器來同時保存 GEMM 的累加器以及 Softmax 的輸入 / 輸出。
低精度:使用非相干處理減少量化誤差
激活 LLM 可能存在一些極端值,導(dǎo)致量化困難,從而產(chǎn)生較大的量化誤差。本文采用非相干處理(incoherent processing),該技術(shù)通過將查詢和鍵與一個隨機正交矩陣相乘來「分散(spread out)」極端值,從而減少量化誤差。特別地,該研究使用了 Hadamard 變換,它可以在每個注意力頭中以 O (d log d) 的時間復(fù)雜度完成,而不是 O (d^2),其中 d 是頭部維度。
研究者發(fā)現(xiàn)非相干處理可以將量化誤差減少很多,具體的數(shù)值誤差比較見下表。
實驗
文中展示了 FlashAttention-3 的一些結(jié)果,并將其與 FlashAttention-2 以及 Triton 和 cuDNN 中的實現(xiàn)進行了比較(兩者都已經(jīng)使用了 Hopper GPU 的新硬件功能)。
在 FP16 精度下,F(xiàn)lashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍。
對于 FP8,F(xiàn)lashAttention-3 接近 1.2 PFLOPS。
擴展閱讀:
斯坦福提出新型Attention算法!提速2-4倍,BERT單節(jié)點訓(xùn)練最快
比標(biāo)準(zhǔn)Attention提速5-9倍,大模型都在用的FlashAttention v2來了
參考鏈接:
https://tridao.me/blog/2024/flash3/ 轉(zhuǎn)自機器之心
推薦閱讀
使用PyTorch 2.0加速Transformer:訓(xùn)練推理均拿下!
機器學(xué)習(xí)算法工程師
一個用心的公眾號
