【深度學(xué)習(xí)】干貨!小顯存如何訓(xùn)練大模型
之前Kaggle有一個(gè)Jigsaw多語言毒舌評論分類[1]比賽,當(dāng)時(shí)我只有一張11G顯存的1080Ti,根本沒法訓(xùn)練SOTA的Roberta-XLM-large模型,只能遺憾躺平。在這篇文章中,我將分享一些關(guān)于如何減少訓(xùn)練時(shí)顯存使用的技巧,以便你可以用現(xiàn)有的GPU訓(xùn)練更大的網(wǎng)絡(luò)。
混合精度訓(xùn)練
第一個(gè)可能已經(jīng)普及的技巧是使用混合精度(mixed-precision)訓(xùn)練。當(dāng)訓(xùn)練一個(gè)模型時(shí),一般來說所有的參數(shù)都會存儲在顯存VRAM中。很簡單,總的VRAM使用量等于存儲的參數(shù)數(shù)量乘以單個(gè)參數(shù)的VRAM使用量。一個(gè)更大的模型不僅意味著更好的性能,而且也會使用更多的VRAM。由于性能相當(dāng)重要,比如在Kaggle比賽中,我們不希望減小模型的規(guī)模。因此減少顯存使用的唯一方法是減少每個(gè)變量的內(nèi)存使用。默認(rèn)情況下變量是32位浮點(diǎn)格式,這樣一個(gè)變量就會消耗4個(gè)字節(jié)。幸運(yùn)的是,人們發(fā)現(xiàn)可以在某些變量上使用16位浮點(diǎn),而不會損失太多的精度。這意味著我們可以減少一半的內(nèi)存消耗! 此外,使用低精度還可以提高訓(xùn)練速度,特別是在支持Tensor Core的GPU上。
在1.5版本之后,pytorch開始支持自動(dòng)混合精度(AMP)訓(xùn)練。該框架可以識別需要全精度的模塊,并對其使用32位浮點(diǎn)數(shù),對其他模塊使用16位浮點(diǎn)數(shù)。下面是Pytorch官方文檔[2]中的一個(gè)示例代碼。
#?Creates?model?and?optimizer?in?default?precision
model?=?Net().cuda()
optimizer?=?optim.SGD(model.parameters(),?...)
#?Creates?a?GradScaler?once?at?the?beginning?of?training.
scaler?=?GradScaler()
for?epoch?in?epochs:
????for?input,?target?in?data:
????????optimizer.zero_grad()
????????#?Runs?the?forward?pass?with?autocasting.
????????with?autocast():
????????????output?=?model(input)
????????????loss?=?loss_fn(output,?target)
????????#?Scales?loss.??Calls?backward()?on?scaled?loss?to?create?scaled?gradients.
????????#?Backward?passes?under?autocast?are?not?recommended.
????????#?Backward?ops?run?in?the?same?dtype?autocast?chose?for?corresponding?forward?ops.
????????scaler.scale(loss).backward()
????????#?scaler.step()?first?unscales?the?gradients?of?the?optimizer's?assigned?params.
????????#?If?these?gradients?do?not?contain?infs?or?NaNs,?optimizer.step()?is?then?called,
????????#?otherwise,?optimizer.step()?is?skipped.
????????scaler.step(optimizer)
????????#?Updates?the?scale?for?next?iteration.
????????scaler.update()
梯度積累
第二個(gè)技巧是使用梯度積累。梯度累積的想法很簡單:在優(yōu)化器更新參數(shù)之前,用相同的模型參數(shù)進(jìn)行幾次前后向傳播。在每次反向傳播時(shí)計(jì)算的梯度被累積(加總)。如果你的實(shí)際batch size是N,而你積累了M步的梯度,你的等效批處理量是N*M。然而,訓(xùn)練結(jié)果不會是嚴(yán)格意義上的相等,因?yàn)橛行﹨?shù),如Batch Normalization,不能完全累積。
關(guān)于梯度累積,有一些事情需要注意:
當(dāng)你在混合精度訓(xùn)練中使用梯度累積時(shí), scale應(yīng)該為有效批次進(jìn)行校準(zhǔn),scale更新應(yīng)該以有效批次的粒度進(jìn)行。當(dāng)你在分布式數(shù)據(jù)并行(DDP)訓(xùn)練中使用梯度累積時(shí),使用no_sync()上下文管理器來禁用前M-1步的梯度全還原,這可以增加訓(xùn)練的速度。
具體的實(shí)現(xiàn)方法可以參考文檔[3]。
梯度檢查點(diǎn)
最后一個(gè),也是最重要的技巧是使用梯度檢查點(diǎn)(Gradient Checkpoint)。Gradient Checkpoint的基本思想是只將一些節(jié)點(diǎn)的中間結(jié)果保存為checkpoint,在反向傳播過程中對這些節(jié)點(diǎn)之間的其他部分進(jìn)行重新計(jì)算。據(jù)Gradient Checkpoint的作者說[4],在這個(gè)技巧的幫助下,他們可以把10倍大的模型放到GPU上,而計(jì)算時(shí)間只增加20%。Pytorch從0.4.0版本開始正式支持這一功能,一些非常常用的庫如Huggingface Transformers也支持這一功能,而且非常簡單,只需要下面的兩行代碼:
bert?=?AutoModel.from_pretrained(pretrained_model_name)
bert.config.gradient_checkpointing=True
實(shí)驗(yàn)
在這篇文章的最后,我想分享之前我在惠普Z4工作站上做的一個(gè)簡單的benchmark。該工作站配備了2個(gè)24G VRAM的RTX6000 GPU(去年底升級到2個(gè)48G的A6000了),在實(shí)驗(yàn)中我只用了一個(gè)GPU。我用不同的配置在Kaggle Jigsaw多語言毒舌評論分類比賽的訓(xùn)練集上訓(xùn)練了XLM-Roberta Base/Large,并觀察顯存的使用量,結(jié)果如下。
| Model | XLM-R Base | XLM-R Base 1 | XLM-R Base 2 | XLM-R Large | XLM-R Large 1 | XLM-R Large 2 |
|---|---|---|---|---|---|---|
| Batch size/GPU | 8 | 8 | 16 | 8 | 8 | 8 |
| Mixed-precision | off | on | on | off | on | on |
| gradient checkpointing | off | off | off | off | off | on |
| VRAM usage | 12.28G | 10.95G | 16.96 | OOM | 23.5G | 11.8G |
| one epoch | 70min | 50min | 40min | - | 100min | 110min |
我們可以看到,混合精度訓(xùn)練不僅減少了內(nèi)存消耗,而且還帶來了顯著的速度提升。梯度檢查點(diǎn)的功能也非常強(qiáng)大。它將VRAM的使用量從23.5G減少到11.8G!
以上就是所有內(nèi)容,希望對大家有幫助??
參考資料
Jigsaw多語言毒舌評論分類: https://www.kaggle.com/c/jigsaw-multilingual-toxic-comment-classification
[2]Pytorch官方文檔: https://pytorch.org/docs/1.8.1/notes/amp_examples.html
[3]gradient-accumulation文檔: https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation
[4]據(jù)Gradient Checkpoint的作者說: https://github.com/cybertronai/gradient-checkpointing
往期精彩回顧
