<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          如何用數(shù)據(jù)并行訓(xùn)練萬億參數(shù)模型?

          共 10955字,需瀏覽 22分鐘

           ·

          2021-08-25 03:05

          點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師

          設(shè)為星標(biāo),干貨直達(dá)!

          近期,F(xiàn)acebook發(fā)布了FSDP(Fully Sharded Data Parallel),這個(gè)是對(duì)標(biāo)微軟在DeepSpeed中提出的ZeRO,F(xiàn)SDP可以看成PyTorch中的DDP優(yōu)化版本,本身也是數(shù)據(jù)并行,但是和DDP不同的是,F(xiàn)SDP采用了parameter sharding,所謂的parameter sharding就是將模型參數(shù)也切分到各個(gè)GPUs上,而DDP每個(gè)GPU都要保存一份parameter,F(xiàn)SDP可以實(shí)現(xiàn)更好的訓(xùn)練效率(速度和顯存使用)。這背后的優(yōu)化邏輯可以從谷歌和微軟的論文中找到。

          Sharding weight update

          對(duì)于典型的數(shù)據(jù)并行實(shí)現(xiàn)(PyTorch的DDP和TF的tf.distribute.MirroredStrategy)來說,每個(gè)replica(GPU)都擁有一份模型參數(shù)和一套o(hù)ptimizer,每個(gè)訓(xùn)練step,數(shù)據(jù)被均分到每個(gè)replica上,每個(gè)replica基于被分到的數(shù)據(jù)單獨(dú)計(jì)算自己的local gradients,然后所有的replicas基于all-reduce操作來得到local gradients的summed gradients,這樣每個(gè)replica其實(shí)都拿到了global gradients,最后基于global gradients更新模型參數(shù)(weight update)。這個(gè)過程如下圖所示:

          其中all-reduce操作(ring all-reduce)包含兩個(gè)操作:reduce-scatter和all-gather。在reduce-scatter階段,gradients被均分成不同的blocks或shards,通過N-1輪交換數(shù)據(jù),每個(gè)replica都得到一份reduced后的shards;在all-grather階段,通過N-1輪數(shù)據(jù)交換,每個(gè)replica都將自己的那份reduced后的shards廣播到其它的replicas,這樣所有的replicas就能得到全部reduced后的gradients。不論有多少replicas,all-reduce的通信成本上是恒定的,這樣就可以實(shí)現(xiàn)線性加速。每個(gè)replicas拿到reduced gradients后都在做重復(fù)的update weight,因?yàn)槊總€(gè)replicas都有模型參數(shù)的一個(gè)copy。如果模型(如NLP中的Transformer)比較大,參數(shù)量多,這個(gè)update weight在訓(xùn)練step中就會(huì)占據(jù)不可忽略的耗時(shí);對(duì)于小模型的大規(guī)模分布式訓(xùn)練,一般每個(gè)device會(huì)采用較小的batch size以防止global batch size過大,此時(shí)update weight也會(huì)成為訓(xùn)練step中的重要耗時(shí)項(xiàng)。為了解決這個(gè)問題,谷歌在2020年提出了sharding weight update,如下圖所示,經(jīng)過reduce-scatter后每個(gè)replica得到一個(gè)gradient shard,每個(gè)replica先更新自己的shard的weight,然后再進(jìn)行all-gather,這樣其實(shí)是和原始的all-reduce是等價(jià)的。但是經(jīng)過這個(gè)調(diào)整,每個(gè)replica只是update weight shard,耗時(shí)就會(huì)降低了,相當(dāng)于update weight也被各個(gè)replica給分擔(dān)了。

          另外一點(diǎn)就是要考慮optimizer,因?yàn)閛ptimizer往往包含額外的參數(shù),比如SGD包含一套參數(shù):gradient的EMA,而Adam包含兩套參數(shù):gradient的EMA和variance,這些參數(shù)可以統(tǒng)稱為optimizer states,它們也是需要同步更新的。當(dāng)模型參數(shù)較大時(shí),optimizer states也會(huì)很大,比如Adam就是模型參數(shù)的2倍,如果也對(duì)optimizer states進(jìn)行all-gather的話,通信成本就會(huì)比較大(原始的all-reduce并不需要)。optimizer states只參與weight update中,但是在下一個(gè)forward和backward中并不需要,不過optimizer states應(yīng)該被包含在模型的checkpoints中,因?yàn)樗鼈円彩莟raining state,比較好的方案是只有當(dāng)需要時(shí)才對(duì)optimizer states進(jìn)行all-gather,這就變成如下圖所示:這里optimizer的auxliary只在Loop body外面才進(jìn)行all-gather以得到final auxliary。另外左圖和右圖的區(qū)別是weight的all-gather的位置不同,左圖weight的all-gather是在update后立即進(jìn)行的,而右圖是在需要的時(shí)候(forward和backward)才進(jìn)行all-gather,看起來像是左邊的方案更好一點(diǎn),因?yàn)樵谧詈蟮玫絝inal weight時(shí)右圖還需要一次all-gather。但是右圖方案有更大的優(yōu)化空間,這是因?yàn)樵趂orward和backward過程中往往不需要高精度的weight,比如TPU中可以采用bfloat16,雖然update weight需要float32。在右圖方案中,可以采用低精度bfloat16來all-gather來得到所需要的全部weight,這樣就大大降低了內(nèi)存使用和通信成本。另外weight和auxliary weight的生存周期也減少了。特別是optimizer的auxliary weight,在training loop中其實(shí)只需要shard,這樣就節(jié)省一部分內(nèi)存空間,可以用來存儲(chǔ)forward和backward中activations和gradients。假定模型參數(shù)大小是W,而auxliary weight大小是V,共有N個(gè)shards,orward和backward中activations和gradients的峰值大小是P,那么訓(xùn)練的峰值大小就從W+V+P降低為max(W+V/N+P,W+V),這帶來的一個(gè)好處是Adam將和SGD一樣高效(Adam比SGD要多一份auxliary weight)。可以看到谷歌提出的sharding weight update不僅可以加速訓(xùn)練,而且也會(huì)節(jié)省顯存,這里只是簡單介紹了論文最核心的優(yōu)化邏輯,論文中還有關(guān)于graph和shard具體實(shí)現(xiàn)細(xì)節(jié)討論。論文中基于ResNet-50,Transformer和NCF三個(gè)模型做實(shí)驗(yàn),實(shí)驗(yàn)配置如下:從實(shí)驗(yàn)結(jié)果來看,無論是CV還是NLP模型在訓(xùn)練耗時(shí)和顯存使用上均有提升,特別是對(duì)大規(guī)模訓(xùn)練的場景(replica batch size?。┖湍P洼^大的場景(Transformer模型):

          ZeRO-DP


          微軟在DeepSpeed中提出的ZeRO(Zero Redundancy Optimizer)出發(fā)點(diǎn)是優(yōu)化內(nèi)存使用,從而提高訓(xùn)練速度,并且可以實(shí)現(xiàn)訓(xùn)練更大的模型。ZeRO包含模型并行ZeRO-R和數(shù)據(jù)并行ZeRO-DP,這里我們只討論數(shù)據(jù)并行ZeRO-DP。ZeRO-DP的出發(fā)點(diǎn)是優(yōu)化model states,這里的model states包括:optimizer states, gradients and parameters,其中optimizer states前面已經(jīng)說過,就是optimizer所需要的參數(shù),對(duì)于Adam其optimizer states是parameters的2倍,而且使用混合精度訓(xùn)練時(shí),optimizer states是fp32,這將成為顯存占用的大頭。

          在混合精度訓(xùn)練中,訓(xùn)練的forward和backward采用的是fp16 weights,activations和gradients,但是weight update需要采用fp32,這就需要optimizer保存一份fp32 weights,而且optimizer states也要采用fp32。假定模型大小是 Ψ,而gradients和parameters均采用fp16,那么消耗的顯存是2Ψ+2Ψ。而Adam需要fp32的parameters,momentum和variance(optimizer states),其消耗的顯存是4Ψ+4Ψ+4Ψ。用K來表示optimizer states的multiplier,那么model states消耗的顯存是(4+K)*Ψ,對(duì)于Adam來說K=12,那么model states消耗的顯存是16Ψ。ZeRO-DP的優(yōu)化策略就是分別對(duì)model states各個(gè)部分進(jìn)行partitioning:

          Optimizer State Partitioning

          如果DP的并行度為(replicas數(shù)量),那么可以將optimizer state均分為個(gè)partitions,這樣第i個(gè)節(jié)點(diǎn)只需要更新optimizer state第i個(gè)partition。此時(shí)每個(gè)節(jié)點(diǎn)只需要存儲(chǔ)和更新所有optimizer state的,而且也只更新parameter的。在每個(gè)training step的最后,只需要執(zhí)行all-gather,每個(gè)節(jié)點(diǎn)就可以獲得更新后的全部parameter??梢杂?jì)算,optimizer State partitioning()消耗的顯存就減少為。這個(gè)優(yōu)化其實(shí)前面谷歌的工作也做了。

          Gradient Partitioning

          既然每個(gè)節(jié)點(diǎn)只需要更新parameter的,那么其實(shí)每個(gè)節(jié)點(diǎn)也只需要對(duì)應(yīng)參數(shù)的gradient。具體地,在backward過程的每個(gè)layer,一旦得到了gradient,每個(gè)節(jié)點(diǎn)就對(duì)自己所需那部分參數(shù)的gradient做reduce(等價(jià)于做一個(gè)reduce-scatter),得到summed gradients,由于其它部分的gradient并不需要了就可以釋放了,從而減少了顯存使用,這可以稱為gradient partitioning()。此時(shí)顯存的消耗降為

          Parameter Partitioning

          更進(jìn)一步地,其實(shí)每個(gè)節(jié)點(diǎn)只需要存儲(chǔ)要更新的那部分參數(shù)就好,在forward和backward過程中,需要全部的weight時(shí)再進(jìn)行all-gather,然后再丟棄,這就是parameter partitioning(),此時(shí)顯存的消耗進(jìn)一步減低為。但是采用parameter partitioning是通信開銷的,論文中實(shí)驗(yàn)說明使用后通信成本增大1.5倍。

          基于ZeRO-DP,當(dāng)時(shí),1T Model(萬億參數(shù))消耗的顯存為15.6GB,模型可以放在一張32GB的V100卡上。其實(shí)可以看到,谷歌的sharding weight update近似等價(jià)于采用的ZeRO-DP,雖然兩個(gè)工作的出發(fā)點(diǎn)不一樣,但是殊途同歸。

          FSDP

          其實(shí)在FSDP之前,F(xiàn)acebook已經(jīng)實(shí)現(xiàn)了optimizer state+gradient sharding DP,這就是采用的ZeRO-DP,或者叫ZeRO-DP-2,這個(gè)實(shí)現(xiàn)包含在fairscale庫中,一個(gè)具體的使用case如下所示:

          import torch
          import torch.distributed as dist
          import torch.multiprocessing as mp
          from fairscale.optim.oss import OSS
          from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

          def train(
              rank: int,
              world_size: int,
              epochs: int)
          :


              # DDP init example
              dist.init_process_group(backend='nccl', init_method="tcp://localhost:29501", rank=rank, world_size=world_size)

              # Problem statement
              model = myAwesomeModel().to(rank)
              dataloader = mySuperFastDataloader()
              loss_fn = myVeryRelevantLoss()
              base_optimizer = torch.optim.SGD # pick any pytorch compliant optimizer here
              base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS

              # Wrap the optimizer in its state sharding brethren
              optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

              # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
              model = ShardedDDP(model, optimizer)

              # Any relevant training loop, nothing specific to OSS. For example:
              model.train()
              for e in range(epochs):
                  for batch in dataloader:
                      # Train
                      model.zero_grad()
                      outputs = model(batch["inputs"])
                      loss = loss_fn(outputs, batch["label"])
                      loss.backward()
                      optimizer.step()

              dist.destroy_process_group()

          if __name__ == "__main__":
              # Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
              mp.spawn(
                  train,
                  args=(
                      WORLD_SIZE,
                      EPOCHS,
                  ),
                  nprocs=WORLD_SIZE,
                  join=True,
              )


          而最近發(fā)布的FSDP更是實(shí)現(xiàn)了完全的ZeRO-DP,而且據(jù)官方說效率更高,更重要的是FSDP可以直接替換PyTorch的DDP,F(xiàn)SDP的特點(diǎn)如下:

          • FSDP對(duì)parameters (FP16 + FP32)和optimizer state進(jìn)行sharding;
          • 當(dāng)reshard_after_forward=False,和PyTorch DDP通信成本一樣,類似ZeRO-DP-2;
          • 當(dāng)reshard_after_forward=True通信成本增加50%,類似ZeRO-DP-3,速度會(huì)慢,但是顯存開銷最小,此時(shí)行為如下:
          FSDP forward pass:
              for layer_i in layers:
                  all-gather full weights for layer_i
                  forward pass for layer_i
                  discard full weights for layer_i
          FSDP backward pass:
              for layer_i in layers:
                  all-gather full weights for layer_i
                  backward pass for layer_i
                  discard full weights for layer_i
                  reduce-scatter gradients for layer_i
          • FSDP通常情況下要比PyTorch DDP快,因?yàn)閛ptimizer step is sharded, 而且額外的通信可以和forward過程交叉;
          • FSDP用8 GPUs可以訓(xùn)練13B parameter models,用128 GPUs可以訓(xùn)練175B parameter models。當(dāng)設(shè)置cpu_offload=True,可以用256 GPUs訓(xùn)練 1T parameter models。
          • FSDP只兼容pointwise Optimizers(Adam, AdamW, Adadelta, Adamax, SGD等),如果是non-pointwise Optimizers(Adagrad, Adafactor, LAMB等),sharding將得到稍微不一樣的結(jié)果。

          使用FSDP很簡單,只需要在代碼中簡單地替換原來的DDP:

          from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
          ...
          sharded_module = DDP(my_module) -> FSDP(my_module)
          optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
          for sample, label in dataload.next_batch:
            out = sharded_module(x=sample, y=3, z=torch.Tensor([1]))
            loss = criterion(out, label)
            loss.backward()
            optim.step()


          結(jié)語

          未來,隨著算力的增強(qiáng),大模型應(yīng)該是趨勢,那么類似FSDP這樣的工具將會(huì)發(fā)揮價(jià)值。PS:本文只是簡單地回顧了FSDP背后所涉及的優(yōu)化邏輯,但是背后的實(shí)現(xiàn)細(xì)節(jié)應(yīng)該遠(yuǎn)不止此,如果錯(cuò)誤,請(qǐng)見解。

          參考

          • Fully Sharded Data Parallel: faster AI training with fewer GPUs
          • https://github.com/microsoft/DeepSpeed
          • ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
          • Automatic Cross-Replica Sharding of Weight Update in Data-Parallel Training



          推薦閱讀

          CPVT:一個(gè)卷積就可以隱式編碼位置信息

          SOTA模型Swin Transformer是如何煉成的!

          谷歌AI用30億數(shù)據(jù)訓(xùn)練了一個(gè)20億參數(shù)Vision Transformer模型,在ImageNet上達(dá)到新的SOTA!

          BatchNorm的避坑指南(上)

          BatchNorm的避坑指南(下)

          目標(biāo)跟蹤入門篇-相關(guān)濾波

          SOTA模型Swin Transformer是如何煉成的!

          MoCo V3:我并不是你想的那樣!

          Transformer在語義分割上的應(yīng)用

          "未來"的經(jīng)典之作ViT:transformer is all you need!

          PVT:可用于密集任務(wù)backbone的金字塔視覺transformer!

          漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA

          Transformer為何能闖入CV界秒殺CNN?

          不妨試試MoCo,來替換ImageNet上pretrain模型!


          機(jī)器學(xué)習(xí)算法工程師


                                              一個(gè)用心的公眾號(hào)


          瀏覽 74
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  俺去一在线三区 | 天天干夜夜一级黄色片 | 国产美女精品视频 | 欧美大屌操逼视频 | 日韩三级在线免费观看 |