在RTX 4090被限制的時(shí)代下,讓大模型使用RLHF更高效的方法來了
本文轉(zhuǎn)載自人工智能與算法學(xué)習(xí)
該論文介紹了一種名為 ReMax 的新算法,專為基于人類反饋的強(qiáng)化學(xué)習(xí)(RLHF)而設(shè)計(jì)。ReMax 在計(jì)算效率(約減少 50% 的 GPU 內(nèi)存和 2 倍的訓(xùn)練速度提升)和實(shí)現(xiàn)簡(jiǎn)易性(6 行代碼)上超越了最常用的算法 PPO,且性能沒有損失。

-
論文鏈接:https://arxiv.org/abs/2310.10505 -
作者:李子牛,許天,張雨舜,俞揚(yáng),孫若愚,羅智泉 -
機(jī)構(gòu):香港中文大學(xué)(深圳),深圳市大數(shù)據(jù)研究院,南京大學(xué),南棲仙策 -
開源代碼:https://github.com/liziniu/ReMax






REINFORCE可以在計(jì)算層面利用好RLHF任務(wù)的三個(gè)性質(zhì),因?yàn)镽EINFORCE直接利用一個(gè)響應(yīng)的獎(jiǎng)勵(lì)來進(jìn)行優(yōu)化,不需要像一般的RL算法一樣需要知道中間步驟的獎(jiǎng)勵(lì)和值函數(shù)。然而,由于策略的隨機(jī)性, REINFORCE梯度估計(jì)器存在高方差問題(在Richard Sutton的RL書里有指出),這一問題會(huì)影響模型訓(xùn)練的有效性,因此REINFORCE在RLHF任務(wù)中的效果較差,見下面兩張圖片。



可以看作為期望獎(jiǎng)勵(lì)
的好的近似。在理想情形下(
),對(duì)于隨機(jī)變量
,
,因此我們能夠期望估計(jì)器
具有更小的方差。

-
ReMax 的核心部分可以用 6 行代碼來實(shí)現(xiàn)。相比之下,PPO 要額外引入重要性采樣(importance sampling),廣義優(yōu)勢(shì)估計(jì)(generalized advantage estimation,GAE),價(jià)值模型學(xué)習(xí)等額外模塊。 -
ReMax 的超參數(shù)很少。相比之下,PPO 有額外的超參數(shù),例如重要性采樣剪切閾值(importance sampling clipping ratio)、GAE 系數(shù)、價(jià)值模型學(xué)習(xí)率,離策略訓(xùn)練輪次(off-policy training epoch)等,這些超參數(shù)都需要花大量時(shí)間去調(diào)優(yōu)。 -
ReMax 能理論上節(jié)省約 50% 內(nèi)存。相比于 PPO,ReMax 成功移除了所有和價(jià)值模型相關(guān)的部件,大大減小了內(nèi)存開銷。通過計(jì)算,我們發(fā)現(xiàn)相比于 PPO,ReMax 能節(jié)省約 50% 內(nèi)存。
-
ReMax 可以像 PPO 一樣有效地最大化獎(jiǎng)勵(lì)


-
在 GPT-4 評(píng)估下(LIMA Test Questions),ReMax 得到的策略比 SFT 和 PPO 會(huì)更好

-
ReMax 能節(jié)省近 50% 的 GPU 內(nèi)存。ReMax 移除掉了價(jià)值模型和它的訓(xùn)練部分(梯度,優(yōu)化器,激活值),從而極大節(jié)省了 GPU 內(nèi)存需求。考慮 Llama2-7B,PPO 無法在 8xA100-40GB 的機(jī)器上跑起來,但是 ReMax 可以。

-
ReMax 能加快 2 倍的訓(xùn)練速度。在每一輪中,ReMax 調(diào)用 2 次生成(generation),1 次反向傳播(backpropagation);而 PPO 使用 1 次生成,2 次反向傳播。對(duì)于大模型而言,生成會(huì)比反向傳播的時(shí)間小,從而 ReMax 可以實(shí)現(xiàn)理論上接近 2 倍的訓(xùn)練加速。


-
更簡(jiǎn)單的實(shí)現(xiàn): ReMax 的核心部分 6 行代碼即可實(shí)現(xiàn)。這與 PPO 中的眾多復(fù)雜的代碼構(gòu)建塊形成鮮明對(duì)比。 -
更少的內(nèi)存開銷:由于移除了價(jià)值模型及其全部訓(xùn)練組件,相比 PPO,ReMax 節(jié)省了大約 50% 的 GPU 內(nèi)存。 -
更少的超參數(shù): ReMax 成功移除了所有和價(jià)值模型訓(xùn)練相關(guān)的超參數(shù),其中包括:GAE 系數(shù)、價(jià)值模型學(xué)習(xí)率、重要性采樣時(shí)期、小批量(mini-batch)大小。這些超參數(shù)往往對(duì)問題敏感且難以調(diào)整。我們相信 ReMax 對(duì) RLHF 研究者更加友好。 -
更快的訓(xùn)練速度:在 GPT2(137M)的實(shí)驗(yàn)中,我們觀察到 ReMax 在真實(shí)運(yùn)行時(shí)間方面相比于 PPO 有 2.2 倍的加速。加速來自 ReMax 每次迭代中較少的計(jì)算開銷。通過我們的計(jì)算,該加速優(yōu)勢(shì)在更大的模型上也能維持(假設(shè)在足夠大的內(nèi)存下 PPO 可以被成功部署)。 -
優(yōu)異的性能:如前所示,ReMax在中等規(guī)模實(shí)驗(yàn)中與PPO實(shí)現(xiàn)了相當(dāng)?shù)男阅埽⑶矣袝r(shí)甚至超越它(可能是由于 ReMax 更容易找到合適的超參數(shù))。我們推測(cè)這種良好的性能可以拓展到更大規(guī)模的模型中
評(píng)論
圖片
表情
