論文閱讀 | Scaling Laws for Transfer
共 6493字,需瀏覽 13分鐘
·
2024-10-16 07:00
論文閱讀 | Scaling Laws for Transfer
作者:張義策
文章地址:https://zhuanlan.zhihu.com/p/710594520
An Empirical Study of Scaling Laws for Transfer
論文地址:https://arxiv.org/abs/2408.16947
這是axriv上的一篇文章,作者來自Epoch AI。
預(yù)訓(xùn)練到下游任務(wù)的知識遷移
為了在下游任務(wù)上取得高性能,一個標(biāo)準(zhǔn)做法是先在一個規(guī)模大、多樣化的語料上預(yù)訓(xùn)練一個基座模型,然后在特定的下游任務(wù)上微調(diào)。這里的微調(diào)也可以使用In-context Learning方法代替。這些方法的有效性取決于從預(yù)訓(xùn)練到下游任務(wù)的知識遷移程度。遷移程度高,那么微調(diào)的成本可以低些;反之,我們就需要多構(gòu)建一些高質(zhì)量的微調(diào)數(shù)據(jù)。
本文的目的是研究上述知識遷移過程的Scaling Laws。本文首先使用一個公式來建模預(yù)訓(xùn)練和微調(diào)對下游任務(wù)的影響,然后在多個下游任務(wù)上進行實驗對這個公式中的參數(shù)進行估計,并展開了一些系列的分析。
Scaling Laws
這篇文章使用下面的公式來建模預(yù)訓(xùn)練和微調(diào)對下游任務(wù)的影響,
其中 p 表示預(yù)訓(xùn)練步長,f 表示微調(diào)階段的樣本數(shù)據(jù)量,L(p,f) 表示在下游任務(wù)上的損失(越小說明在下游任務(wù)上性能越高), G 表示遷移差距(transfer gap,即預(yù)訓(xùn)練語料和下游任務(wù)之間的差距),E 表示無法降低的損失(和任務(wù)相關(guān),是損失的下界), α>0 和 β>0 是預(yù)訓(xùn)練和微調(diào)的衰減系數(shù)。
上面的公式可能比較抽象。讓我們試著理解一下。
當(dāng) f=0 即不進行微調(diào), 。下游任務(wù)的損失和 p 呈負冪次的關(guān)系,如下圖所示。隨著 p 趨近無窮, 。這意味著當(dāng)預(yù)訓(xùn)練做到極致,下游任務(wù)上的損失等于遷移差距和任務(wù)下界的損失之和。
函數(shù)圖像來自https://www.geogebra.org/graphing?lang=zh_CN。
當(dāng) 即預(yù)訓(xùn)練固定步數(shù),記 ,那么有 。下游任務(wù)的損失和 f 也呈負冪次的關(guān)系。隨著 f 趨近無窮, 。這意味著當(dāng)微調(diào)數(shù)量足夠多,下游任務(wù)上的損失逼近下界。
上面的Scaling Law公式有5個參數(shù) ,這都需要根據(jù)實驗來估計得到。
實驗設(shè)置
模型選擇了Pythia-2.8b結(jié)構(gòu),預(yù)訓(xùn)練數(shù)據(jù)集選擇Pile,下游的微調(diào)數(shù)據(jù)集選擇了5個,如下表所示。
預(yù)訓(xùn)練階段的batch_size設(shè)置為2,097,152個tokens,最大預(yù)訓(xùn)練步長為143,000,一共保存了15個checkpoints。接下來對這些checkpoints,在下游數(shù)據(jù)集上進行微調(diào),分別在10種微調(diào)樣本數(shù)量的設(shè)置下,記錄模型的損失。
接下來,對每個下游數(shù)據(jù)集,使用實驗得到的 數(shù)據(jù)來估計Scaling Laws中的參數(shù) 。
實驗結(jié)果
在5個下游數(shù)據(jù)集上參數(shù)估計的結(jié)果如下表所示。
我們可以得到如下的分析結(jié)果:
-
預(yù)訓(xùn)練的衰減系數(shù) 在不同任務(wù)上差別不大。相比之下,微調(diào)的衰減系數(shù) 的差異就大很多了。 -
差別不大,那預(yù)訓(xùn)練對不同任務(wù)的差異主要取決于遷移差距 ,也就是預(yù)訓(xùn)練數(shù)據(jù)和微調(diào)數(shù)據(jù)的分布差異。 -
令作者感到驚訝的是,house cat genome(一個基因數(shù)據(jù)集)上的遷移差距很小,只有0.548。作者認為這是因為這個數(shù)據(jù)集上的無法降低損失 E 比較大,達到了2.677,表明這個數(shù)據(jù)集具有比較高的內(nèi)在熵、相比來說不好學(xué)習(xí)。 -
盡管對每個任務(wù)只使用了150個數(shù)據(jù)點,但是每個參數(shù)的置信區(qū)間還是比較小,說明整體估計方法還是robust的。(這部分結(jié)果在論文的表3中)。 此外,作者還估計了在更大尺寸模型上的進行實驗的計算代價。對于Llama 3 70B,需要4.77 10^16的FLOP。
進一步的討論:預(yù)訓(xùn)練和微調(diào)的trade-off
預(yù)訓(xùn)練和微調(diào)對下游任務(wù)損失的影響 考慮這樣一個問題:當(dāng)我們用于預(yù)訓(xùn)練和微調(diào)的預(yù)算固定,我們應(yīng)該怎樣分配這些預(yù)算來達到下游任務(wù)上的損失最小?使用Scaling Law for Transfer,我們就可以決定是多花錢來收集更多微調(diào)數(shù)據(jù),還是用來做更大規(guī)模的預(yù)訓(xùn)練。
具體來說,記預(yù)算為 B , 為預(yù)訓(xùn)練一步的成本, 為收集一條微調(diào)數(shù)據(jù)的成本。這樣的話,上面的成本問題就可以形式化為下面的優(yōu)化問題:
基于這樣的優(yōu)化目標(biāo),我們可以大致得到下面的結(jié)論:
-
當(dāng)遷移差距較小的時候,應(yīng)該多做預(yù)訓(xùn)練;如果遷移差距較大,應(yīng)該多收集微調(diào)數(shù)據(jù)。 -
如果 比較大,也就是收集微調(diào)數(shù)據(jù)的成本比較高,那應(yīng)該多花預(yù)算在預(yù)訓(xùn)練上。(感覺在說廢話)
