如何用數(shù)據(jù)并行訓(xùn)練萬億參數(shù)模型?
點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師”
設(shè)為星標(biāo),干貨直達(dá)!
近期,F(xiàn)acebook發(fā)布了FSDP(Fully Sharded Data Parallel),這個(gè)是對(duì)標(biāo)微軟在DeepSpeed中提出的ZeRO,F(xiàn)SDP可以看成PyTorch中的DDP優(yōu)化版本,本身也是數(shù)據(jù)并行,但是和DDP不同的是,F(xiàn)SDP采用了parameter sharding,所謂的parameter sharding就是將模型參數(shù)也切分到各個(gè)GPUs上,而DDP每個(gè)GPU都要保存一份parameter,F(xiàn)SDP可以實(shí)現(xiàn)更好的訓(xùn)練效率(速度和顯存使用)。這背后的優(yōu)化邏輯可以從谷歌和微軟的論文中找到。
Sharding weight update
對(duì)于典型的數(shù)據(jù)并行實(shí)現(xiàn)(PyTorch的DDP和TF的tf.distribute.MirroredStrategy)來說,每個(gè)replica(GPU)都擁有一份模型參數(shù)和一套o(hù)ptimizer,每個(gè)訓(xùn)練step,數(shù)據(jù)被均分到每個(gè)replica上,每個(gè)replica基于被分到的數(shù)據(jù)單獨(dú)計(jì)算自己的local gradients,然后所有的replicas基于all-reduce操作來得到local gradients的summed gradients,這樣每個(gè)replica其實(shí)都拿到了global gradients,最后基于global gradients更新模型參數(shù)(weight update)。這個(gè)過程如下圖所示:
其中all-reduce操作(ring all-reduce)包含兩個(gè)操作:reduce-scatter和all-gather。在reduce-scatter階段,gradients被均分成不同的blocks或shards,通過N-1輪交換數(shù)據(jù),每個(gè)replica都得到一份reduced后的shards;在all-grather階段,通過N-1輪數(shù)據(jù)交換,每個(gè)replica都將自己的那份reduced后的shards廣播到其它的replicas,這樣所有的replicas就能得到全部reduced后的gradients。不論有多少replicas,all-reduce的通信成本上是恒定的,這樣就可以實(shí)現(xiàn)線性加速。
每個(gè)replicas拿到reduced gradients后都在做重復(fù)的update weight,因?yàn)槊總€(gè)replicas都有模型參數(shù)的一個(gè)copy。如果模型(如NLP中的Transformer)比較大,參數(shù)量多,這個(gè)update weight在訓(xùn)練step中就會(huì)占據(jù)不可忽略的耗時(shí);對(duì)于小模型的大規(guī)模分布式訓(xùn)練,一般每個(gè)device會(huì)采用較小的batch size以防止global batch size過大,此時(shí)update weight也會(huì)成為訓(xùn)練step中的重要耗時(shí)項(xiàng)。為了解決這個(gè)問題,谷歌在2020年提出了sharding weight update,如下圖所示,經(jīng)過reduce-scatter后每個(gè)replica得到一個(gè)gradient shard,每個(gè)replica先更新自己的shard的weight,然后再進(jìn)行all-gather,這樣其實(shí)是和原始的all-reduce是等價(jià)的。但是經(jīng)過這個(gè)調(diào)整,每個(gè)replica只是update weight shard,耗時(shí)就會(huì)降低了,相當(dāng)于update weight也被各個(gè)replica給分擔(dān)了。

另外一點(diǎn)就是要考慮optimizer,因?yàn)閛ptimizer往往包含額外的參數(shù),比如SGD包含一套參數(shù):gradient的EMA,而Adam包含兩套參數(shù):gradient的EMA和variance,這些參數(shù)可以統(tǒng)稱為optimizer states,它們也是需要同步更新的。當(dāng)模型參數(shù)較大時(shí),optimizer states也會(huì)很大,比如Adam就是模型參數(shù)的2倍,如果也對(duì)optimizer states進(jìn)行all-gather的話,通信成本就會(huì)比較大(原始的all-reduce并不需要)。optimizer states只參與weight update中,但是在下一個(gè)forward和backward中并不需要,不過optimizer states應(yīng)該被包含在模型的checkpoints中,因?yàn)樗鼈円彩莟raining state,比較好的方案是只有當(dāng)需要時(shí)才對(duì)optimizer states進(jìn)行all-gather,這就變成如下圖所示:
這里optimizer的auxliary只在Loop body外面才進(jìn)行all-gather以得到final auxliary。另外左圖和右圖的區(qū)別是weight的all-gather的位置不同,左圖weight的all-gather是在update后立即進(jìn)行的,而右圖是在需要的時(shí)候(forward和backward)才進(jìn)行all-gather,看起來像是左邊的方案更好一點(diǎn),因?yàn)樵谧詈蟮玫絝inal weight時(shí)右圖還需要一次all-gather。但是右圖方案有更大的優(yōu)化空間,這是因?yàn)樵趂orward和backward過程中往往不需要高精度的weight,比如TPU中可以采用bfloat16,雖然update weight需要float32。在右圖方案中,可以采用低精度bfloat16來all-gather來得到所需要的全部weight,這樣就大大降低了內(nèi)存使用和通信成本。另外weight和auxliary weight的生存周期也減少了。特別是optimizer的auxliary weight,在training loop中其實(shí)只需要shard,這樣就節(jié)省一部分內(nèi)存空間,可以用來存儲(chǔ)forward和backward中activations和gradients。假定模型參數(shù)大小是W,而auxliary weight大小是V,共有N個(gè)shards,orward和backward中activations和gradients的峰值大小是P,那么訓(xùn)練的峰值大小就從W+V+P降低為max(W+V/N+P,W+V),這帶來的一個(gè)好處是Adam將和SGD一樣高效(Adam比SGD要多一份auxliary weight)。
可以看到谷歌提出的sharding weight update不僅可以加速訓(xùn)練,而且也會(huì)節(jié)省顯存,這里只是簡單介紹了論文最核心的優(yōu)化邏輯,論文中還有關(guān)于graph和shard具體實(shí)現(xiàn)細(xì)節(jié)討論。論文中基于ResNet-50,Transformer和NCF三個(gè)模型做實(shí)驗(yàn),實(shí)驗(yàn)配置如下:
從實(shí)驗(yàn)結(jié)果來看,無論是CV還是NLP模型在訓(xùn)練耗時(shí)和顯存使用上均有提升,特別是對(duì)大規(guī)模訓(xùn)練的場景(replica batch size?。┖湍P洼^大的場景(Transformer模型):
ZeRO-DP
微軟在DeepSpeed中提出的ZeRO(Zero Redundancy Optimizer)出發(fā)點(diǎn)是優(yōu)化內(nèi)存使用,從而提高訓(xùn)練速度,并且可以實(shí)現(xiàn)訓(xùn)練更大的模型。ZeRO包含模型并行ZeRO-R和數(shù)據(jù)并行ZeRO-DP,這里我們只討論數(shù)據(jù)并行ZeRO-DP。ZeRO-DP的出發(fā)點(diǎn)是優(yōu)化model states,這里的model states包括:optimizer states, gradients and parameters,其中optimizer states前面已經(jīng)說過,就是optimizer所需要的參數(shù),對(duì)于Adam其optimizer states是parameters的2倍,而且使用混合精度訓(xùn)練時(shí),optimizer states是fp32,這將成為顯存占用的大頭。
在混合精度訓(xùn)練中,訓(xùn)練的forward和backward采用的是fp16 weights,activations和gradients,但是weight update需要采用fp32,這就需要optimizer保存一份fp32 weights,而且optimizer states也要采用fp32。假定模型大小是 Ψ,而gradients和parameters均采用fp16,那么消耗的顯存是2Ψ+2Ψ。而Adam需要fp32的parameters,momentum和variance(optimizer states),其消耗的顯存是4Ψ+4Ψ+4Ψ。用K來表示optimizer states的multiplier,那么model states消耗的顯存是(4+K)*Ψ,對(duì)于Adam來說K=12,那么model states消耗的顯存是16Ψ。ZeRO-DP的優(yōu)化策略就是分別對(duì)model states各個(gè)部分進(jìn)行partitioning:
Optimizer State Partitioning
如果DP的并行度為(replicas數(shù)量),那么可以將optimizer state均分為個(gè)partitions,這樣第i個(gè)節(jié)點(diǎn)只需要更新optimizer state第i個(gè)partition。此時(shí)每個(gè)節(jié)點(diǎn)只需要存儲(chǔ)和更新所有optimizer state的,而且也只更新parameter的。在每個(gè)training step的最后,只需要執(zhí)行all-gather,每個(gè)節(jié)點(diǎn)就可以獲得更新后的全部parameter??梢杂?jì)算,optimizer State partitioning()消耗的顯存就減少為。這個(gè)優(yōu)化其實(shí)前面谷歌的工作也做了。
Gradient Partitioning
既然每個(gè)節(jié)點(diǎn)只需要更新parameter的,那么其實(shí)每個(gè)節(jié)點(diǎn)也只需要對(duì)應(yīng)參數(shù)的gradient。具體地,在backward過程的每個(gè)layer,一旦得到了gradient,每個(gè)節(jié)點(diǎn)就對(duì)自己所需那部分參數(shù)的gradient做reduce(等價(jià)于做一個(gè)reduce-scatter),得到summed gradients,由于其它部分的gradient并不需要了就可以釋放了,從而減少了顯存使用,這可以稱為gradient partitioning()。此時(shí)顯存的消耗降為。
Parameter Partitioning
更進(jìn)一步地,其實(shí)每個(gè)節(jié)點(diǎn)只需要存儲(chǔ)要更新的那部分參數(shù)就好,在forward和backward過程中,需要全部的weight時(shí)再進(jìn)行all-gather,然后再丟棄,這就是parameter partitioning(),此時(shí)顯存的消耗進(jìn)一步減低為。但是采用parameter partitioning是通信開銷的,論文中實(shí)驗(yàn)說明使用后通信成本增大1.5倍。
基于ZeRO-DP,當(dāng)時(shí),1T Model(萬億參數(shù))消耗的顯存為15.6GB,模型可以放在一張32GB的V100卡上。
其實(shí)可以看到,谷歌的sharding weight update近似等價(jià)于采用的ZeRO-DP,雖然兩個(gè)工作的出發(fā)點(diǎn)不一樣,但是殊途同歸。
FSDP
其實(shí)在FSDP之前,F(xiàn)acebook已經(jīng)實(shí)現(xiàn)了optimizer state+gradient sharding DP,這就是采用的ZeRO-DP,或者叫ZeRO-DP-2,這個(gè)實(shí)現(xiàn)包含在fairscale庫中,一個(gè)具體的使用case如下所示:
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
def train(
rank: int,
world_size: int,
epochs: int):
# DDP init example
dist.init_process_group(backend='nccl', init_method="tcp://localhost:29501", rank=rank, world_size=world_size)
# Problem statement
model = myAwesomeModel().to(rank)
dataloader = mySuperFastDataloader()
loss_fn = myVeryRelevantLoss()
base_optimizer = torch.optim.SGD # pick any pytorch compliant optimizer here
base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS
# Wrap the optimizer in its state sharding brethren
optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
# Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
model = ShardedDDP(model, optimizer)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
for e in range(epochs):
for batch in dataloader:
# Train
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
loss.backward()
optimizer.step()
dist.destroy_process_group()
if __name__ == "__main__":
# Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
mp.spawn(
train,
args=(
WORLD_SIZE,
EPOCHS,
),
nprocs=WORLD_SIZE,
join=True,
)
而最近發(fā)布的FSDP更是實(shí)現(xiàn)了完全的ZeRO-DP,而且據(jù)官方說效率更高,更重要的是FSDP可以直接替換PyTorch的DDP,F(xiàn)SDP的特點(diǎn)如下:
FSDP對(duì)parameters (FP16 + FP32)和optimizer state進(jìn)行sharding; 當(dāng)reshard_after_forward=False,和PyTorch DDP通信成本一樣,類似ZeRO-DP-2; 當(dāng)reshard_after_forward=True通信成本增加50%,類似ZeRO-DP-3,速度會(huì)慢,但是顯存開銷最小,此時(shí)行為如下:
FSDP forward pass:
for layer_i in layers:
all-gather full weights for layer_i
forward pass for layer_i
discard full weights for layer_i
FSDP backward pass:
for layer_i in layers:
all-gather full weights for layer_i
backward pass for layer_i
discard full weights for layer_i
reduce-scatter gradients for layer_i
FSDP通常情況下要比PyTorch DDP快,因?yàn)閛ptimizer step is sharded, 而且額外的通信可以和forward過程交叉; FSDP用8 GPUs可以訓(xùn)練13B parameter models,用128 GPUs可以訓(xùn)練175B parameter models。當(dāng)設(shè)置cpu_offload=True,可以用256 GPUs訓(xùn)練 1T parameter models。 FSDP只兼容pointwise Optimizers(Adam, AdamW, Adadelta, Adamax, SGD等),如果是non-pointwise Optimizers(Adagrad, Adafactor, LAMB等),sharding將得到稍微不一樣的結(jié)果。
使用FSDP很簡單,只需要在代碼中簡單地替換原來的DDP:
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
...
sharded_module = DDP(my_module) -> FSDP(my_module)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
for sample, label in dataload.next_batch:
out = sharded_module(x=sample, y=3, z=torch.Tensor([1]))
loss = criterion(out, label)
loss.backward()
optim.step()
結(jié)語
未來,隨著算力的增強(qiáng),大模型應(yīng)該是趨勢,那么類似FSDP這樣的工具將會(huì)發(fā)揮價(jià)值。PS:本文只是簡單地回顧了FSDP背后所涉及的優(yōu)化邏輯,但是背后的實(shí)現(xiàn)細(xì)節(jié)應(yīng)該遠(yuǎn)不止此,如果錯(cuò)誤,請(qǐng)見解。
參考
Fully Sharded Data Parallel: faster AI training with fewer GPUs https://github.com/microsoft/DeepSpeed ZeRO: Memory Optimizations Toward Training Trillion Parameter Models Automatic Cross-Replica Sharding of Weight Update in Data-Parallel Training
推薦閱讀
谷歌AI用30億數(shù)據(jù)訓(xùn)練了一個(gè)20億參數(shù)Vision Transformer模型,在ImageNet上達(dá)到新的SOTA!
"未來"的經(jīng)典之作ViT:transformer is all you need!
PVT:可用于密集任務(wù)backbone的金字塔視覺transformer!
漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA
不妨試試MoCo,來替換ImageNet上pretrain模型!
機(jī)器學(xué)習(xí)算法工程師
一個(gè)用心的公眾號(hào)

